Initial commit: HP Marketing Materials GraphRAG Chatbot
Full-stack GraphRAG chatbot for HP marketing materials with: - Python/Flask backend with custom ReAct agent (LlamaIndex) - Neo4j knowledge graph + vector search hybrid retrieval - LlamaParse multimodal document processing (text + images) - React/Vite frontend with conversation management - MongoDB conversation persistence - MSAL authentication support Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
594f749d4c
48 changed files with 17179 additions and 0 deletions
39
.gitignore
vendored
Normal file
39
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
# Environment variables (contain API keys and credentials)
|
||||
.env
|
||||
.env.*
|
||||
env
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
venv/
|
||||
.venv/
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Node
|
||||
chat-interface/node_modules/
|
||||
chat-interface/dist/
|
||||
|
||||
# Generated data (runtime artifacts, not source code)
|
||||
uploads/
|
||||
index_storage/
|
||||
*.log
|
||||
|
||||
# Large binary source documents (too large for git)
|
||||
supporting_files/
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
|
||||
# Claude Code
|
||||
.claude/
|
||||
118
CLAUDE.md
Normal file
118
CLAUDE.md
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
# CLAUDE.md
|
||||
|
||||
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
|
||||
|
||||
## Project Overview
|
||||
|
||||
HP Marketing Materials Chatbot — a GraphRAG (Graph Retrieval-Augmented Generation) system that combines vector search with knowledge graph capabilities to answer questions about HP marketing materials and brand guidelines. Processes multimodal documents (text + images) via a custom ReAct agent.
|
||||
|
||||
## Development Commands
|
||||
|
||||
### Backend
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
python main.py # Starts Hypercorn ASGI server on localhost:8746
|
||||
```
|
||||
|
||||
### Frontend
|
||||
```bash
|
||||
cd chat-interface
|
||||
npm install
|
||||
npm run dev # Vite dev server
|
||||
npm run build # Production build to dist/
|
||||
npm run lint # ESLint
|
||||
```
|
||||
|
||||
### Required Services
|
||||
- **Neo4j**: Port 7688, credentials `neo4j/hp-graphrag-2024` (HP-dedicated instance; port 7687 is a separate Netflix project)
|
||||
- **MongoDB**: URI `mongodb://hp:hp@localhost:27017/?authSource=hp_chatbot`, database `hp_chatbot`
|
||||
|
||||
### Environment Variables
|
||||
Backend requires `.env` at project root with: `OPENAI_API_KEY`, `LLAMA_CLOUD_API_KEY`, `NEO4J_URL`, `NEO4J_USERNAME`, `NEO4J_PASSWORD`, `PORT` (default 8746).
|
||||
Frontend uses `chat-interface/.env` with: `VITE_BACKEND_URL`, `VITE_APP_BASE_URL`.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Request Flow
|
||||
1. Frontend (`App.jsx`) sends POST to `/chat` with `{message, sessionId}`
|
||||
2. `routes.py:chat()` maps session to conversation via `session_manager.py` and MongoDB
|
||||
3. The global `ReActAgent2` (from `shared_state.py`) processes the query
|
||||
4. Agent uses two tools: vector search (`answer_questions_from_hp_marketing_materials`) and GraphRAG hybrid search (`answerquestionswith_graphrag`)
|
||||
5. Response includes text, sources, reasoning steps, and image references
|
||||
6. Images are served via `/images/<filename>` from `uploads/images/`
|
||||
|
||||
### Shared State Pattern (Critical)
|
||||
All modules access the AI agent, vector index, and GraphRAG components through `shared_state.py` — a module with global variables and setter/getter functions. This avoids circular imports and ensures all modules reference the same instances. **Never import these globals directly from `ai_core.py`; always use `shared_state`.**
|
||||
|
||||
Key globals: `global_workflow_agent`, `global_index`, `global_graph_store`, `global_graphrag_query_engine`
|
||||
|
||||
### ReAct Agent (`ai_core.py`)
|
||||
`ReActAgent2` is a custom LlamaIndex `Workflow` subclass implementing a ReAct loop:
|
||||
- Steps: `new_user_msg` → `prepare_chat_history` → `handle_llm_input` → (tool calls via `handle_tool_calls` → loop back) → `StopEvent`
|
||||
- Has a `simple_run()` method monkey-patched onto the agent at initialization time (replaces the default workflow `run`)
|
||||
- Includes regex-based cleaning of LLM "thinking" artifacts from final responses
|
||||
- Timeouts: `AGENT_TIMEOUT` (600s overall), `LLM_TIMEOUT` (300s per call), `TOOL_EXECUTION_TIMEOUT` (300s per tool)
|
||||
|
||||
### GraphRAG System (`graph_rag_integration.py`)
|
||||
Three main classes:
|
||||
- **`GraphRAGExtractor`**: LlamaIndex `TransformComponent` that extracts entity-relation triplets from text nodes using LLM
|
||||
- **`GraphRAGStore`**: Wraps `Neo4jPropertyGraphStore`, adds community detection (tries graspologic → python-louvain → NetworkX fallback), caches community summaries to `index_storage/graphrag_cache/` as pickle files
|
||||
- **`GraphRAGQueryEngine`**: Combines vector retrieval with community-based graph retrieval, returning both contexts for synthesis
|
||||
|
||||
### Startup Sequence (`main.py`)
|
||||
1. MongoDB initialization (`init_mongodb.py`)
|
||||
2. `initialize_global_index()` in `ai_core.py`:
|
||||
- Configures LLM (chatgpt-4o-latest) and embeddings (text-embedding-3-small)
|
||||
- Loads existing vector index from `index_storage/hp_docs_index/` or builds new from `supporting_files/files_for_rag_store/`
|
||||
- Connects to Neo4j, creates/loads GraphRAG components
|
||||
- Builds communities (from cache or fresh)
|
||||
- Creates `ReActAgent2` and stores in `shared_state`
|
||||
|
||||
### Frontend (React + Vite)
|
||||
- Single-page app in `chat-interface/`, main component is `App.jsx`
|
||||
- Auth via MSAL (`auth.js`), username sent as `X-MS-USERNAME` header
|
||||
- Dev mode uses fallback `dev_user@local` username
|
||||
- Conversation sidebar with auto-width resizing
|
||||
- Markdown rendering via showdown, image viewer with pagination
|
||||
- Styling: TailwindCSS + Shadcn/ui + Radix tooltips
|
||||
|
||||
### JSON Serialization
|
||||
`json_utils.py` provides `CustomJSONEncoder` and `CustomJSONProvider` that handle LlamaIndex types (ToolOutput, ReasoningSteps, ChatMessage, etc.), BSON ObjectId, and datetime. Flask is configured to use this provider globally.
|
||||
|
||||
### Document Processing Pipeline
|
||||
Upload → LlamaParse (dual: text + images) → Semantic splitting (`SemanticSplitterNodeParser`) → Page-based image assignment to chunks → Dual indexing (vector store + Neo4j knowledge graph) → Community detection and caching
|
||||
|
||||
## API Endpoints
|
||||
|
||||
| Method | Path | Purpose |
|
||||
|--------|------|---------|
|
||||
| POST | `/chat` | Main chat endpoint |
|
||||
| GET | `/status` | System status (always returns initialized=true) |
|
||||
| GET | `/images/<filename>` | Serve document images |
|
||||
| GET | `/list-images` | List available images |
|
||||
| GET | `/conversations` | List user conversations |
|
||||
| GET | `/conversations/<id>/messages` | Get conversation messages |
|
||||
| POST | `/conversations/new` | Create new conversation |
|
||||
| DELETE | `/conversations/<id>` | Delete conversation (soft by default) |
|
||||
| POST | `/reset` | Reset global agent memory |
|
||||
| POST | `/download-brief` | Generate Word doc from markdown |
|
||||
| POST | `/capture-screenshot` | Manual LlamaParse image capture (dev only) |
|
||||
| GET | `/debug-status` | Debug endpoint (dev only) |
|
||||
| POST | `/reinitialize` | Force agent reinit (dev only) |
|
||||
|
||||
## Key Conventions
|
||||
|
||||
- Use `log_structured(level, message, data_dict)` from `utils.py` for all logging — it handles safe serialization of LlamaIndex objects
|
||||
- Session management maps frontend `sessionId` to MongoDB `conversation_id` via `session_manager.py` with an in-memory cache backed by MongoDB
|
||||
- The agent is a single global instance — all users share the same agent, but conversation history is loaded per-request from MongoDB
|
||||
- Community summaries use `gpt-4o-mini` for cost efficiency; main agent uses `chatgpt-4o-latest`
|
||||
|
||||
## Deployment
|
||||
|
||||
- Backend: Set `PRODUCTION=true` env var, deploys to `https://ai-sandbox.oliver.solutions/hp_chatbot_back`
|
||||
- Frontend: `npm run build`, deploy `dist/` to `/hp_chatbot/` path at `https://ai-sandbox.oliver.solutions/hp_chatbot/`
|
||||
- CORS origins configured in `config.py:CORS_ALLOWED_ORIGINS`
|
||||
|
||||
## Testing
|
||||
|
||||
No formal test suite. Manual testing: start backend (`python main.py`), start frontend (`npm run dev`), test chat + image responses + document citations.
|
||||
244
README.md
Normal file
244
README.md
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
# HP Marketing Materials Chatbot
|
||||
|
||||
A GraphRAG (Graph Retrieval-Augmented Generation) chatbot that answers questions about HP marketing materials and brand guidelines. Combines vector search with a Neo4j knowledge graph for more comprehensive retrieval, and processes multimodal documents (text + images) using LlamaParse.
|
||||
|
||||
## Features
|
||||
|
||||
- **Hybrid retrieval**: Vector search + knowledge graph community detection for richer context
|
||||
- **Multimodal document processing**: Extracts text and page images from PDFs via LlamaParse
|
||||
- **Custom ReAct agent**: LlamaIndex-based workflow with tool use, reasoning steps, and source citations
|
||||
- **Conversation persistence**: MongoDB-backed chat history with multi-conversation support
|
||||
- **Image references**: Responses include relevant document page screenshots
|
||||
- **Brief export**: Download conversation summaries as Word documents
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- **Python 3.10+**
|
||||
- **Node.js 18+**
|
||||
- **Neo4j** (dedicated instance on port 7688)
|
||||
- **MongoDB** (with authentication configured)
|
||||
- **API Keys**: OpenAI (`OPENAI_API_KEY`), LlamaCloud (`LLAMA_CLOUD_API_KEY`)
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Backend Setup
|
||||
|
||||
```bash
|
||||
# Create and activate a virtual environment
|
||||
python -m venv venv
|
||||
source venv/bin/activate # or venv\Scripts\activate on Windows
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Create .env file at project root
|
||||
cat > .env << 'EOF'
|
||||
OPENAI_API_KEY=your_openai_key
|
||||
LLAMA_CLOUD_API_KEY=your_llama_cloud_key
|
||||
NEO4J_URL=bolt://localhost:7688
|
||||
NEO4J_USERNAME=neo4j
|
||||
NEO4J_PASSWORD=hp-graphrag-2024
|
||||
PORT=8746
|
||||
PRODUCTION=false
|
||||
LOG_LEVEL=INFO
|
||||
EOF
|
||||
|
||||
# Start the server
|
||||
python main.py
|
||||
```
|
||||
|
||||
The backend runs on `http://localhost:8746`. On first startup it will:
|
||||
1. Initialize MongoDB collections and indexes
|
||||
2. Load or build the vector index from `supporting_files/files_for_rag_store/`
|
||||
3. Connect to Neo4j and build/load the knowledge graph
|
||||
4. Build community summaries (cached to `index_storage/graphrag_cache/`)
|
||||
|
||||
### 2. Frontend Setup
|
||||
|
||||
```bash
|
||||
cd chat-interface
|
||||
|
||||
# Install dependencies
|
||||
npm install
|
||||
|
||||
# Create .env file
|
||||
cat > .env << 'EOF'
|
||||
VITE_BACKEND_URL=http://localhost:8746
|
||||
VITE_APP_BASE_URL=/
|
||||
EOF
|
||||
|
||||
# Start dev server
|
||||
npm run dev
|
||||
```
|
||||
|
||||
The frontend runs on `http://localhost:5173`.
|
||||
|
||||
### 3. Database Setup
|
||||
|
||||
**Neo4j:**
|
||||
- Run a Neo4j instance on port 7688 (port 7687 is reserved for a separate project)
|
||||
- Credentials: `neo4j` / `hp-graphrag-2024`
|
||||
- The application auto-populates the graph on first index build
|
||||
|
||||
**MongoDB:**
|
||||
- Create a user `hp` with password `hp` and `authSource=hp_chatbot`
|
||||
- Database: `hp_chatbot`
|
||||
- Collections (`users`, `conversations`, `messages`) are auto-created by `init_mongodb.py` on startup
|
||||
|
||||
Example MongoDB user setup:
|
||||
```javascript
|
||||
use hp_chatbot
|
||||
db.createUser({
|
||||
user: "hp",
|
||||
pwd: "hp",
|
||||
roles: [{ role: "readWrite", db: "hp_chatbot" }]
|
||||
})
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
├── main.py # Entry point, Hypercorn ASGI server
|
||||
├── config.py # Centralized configuration
|
||||
├── ai_core.py # ReAct agent, document processing, index init
|
||||
├── graph_rag_integration.py # GraphRAG: extraction, community detection, query engine
|
||||
├── routes.py # Flask API endpoints
|
||||
├── shared_state.py # Global state for agent/index/graph (cross-module)
|
||||
├── session_manager.py # Session-to-conversation mapping
|
||||
├── mongodb_utils.py # MongoDB CRUD operations
|
||||
├── json_utils.py # Custom JSON serialization for LlamaIndex types
|
||||
├── document_generator.py # Markdown-to-Word document conversion
|
||||
├── utils.py # Logging and file utilities
|
||||
├── init_mongodb.py # Database initialization script
|
||||
├── requirements.txt # Python dependencies
|
||||
├── .env # Environment variables (not committed)
|
||||
├── supporting_files/
|
||||
│ └── files_for_rag_store/ # HP marketing documents for indexing
|
||||
├── uploads/
|
||||
│ └── images/ # Extracted document page images
|
||||
├── index_storage/
|
||||
│ ├── hp_docs_index/ # Persisted vector index
|
||||
│ └── graphrag_cache/ # Cached community summaries (pickle)
|
||||
└── chat-interface/ # React frontend
|
||||
├── src/
|
||||
│ ├── App.jsx # Main chat interface component
|
||||
│ ├── auth.js # MSAL authentication
|
||||
│ ├── components/
|
||||
│ │ ├── ChatInterface.jsx
|
||||
│ │ ├── ConversationManager.jsx
|
||||
│ │ └── ThemeToggle.jsx
|
||||
│ └── lib/utils.js
|
||||
├── package.json
|
||||
└── dist/ # Production build output
|
||||
```
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```
|
||||
┌─────────────┐ POST /chat ┌──────────────┐
|
||||
│ React UI │ ──────────────────► │ Flask/ │
|
||||
│ (App.jsx) │ ◄────────────────── │ Hypercorn │
|
||||
│ │ JSON response │ (routes.py) │
|
||||
└─────────────┘ └──────┬───────┘
|
||||
│
|
||||
┌──────▼───────┐
|
||||
│ Session Mgr │──── MongoDB
|
||||
│ │ (conversations,
|
||||
└──────┬───────┘ messages, users)
|
||||
│
|
||||
┌──────▼───────┐
|
||||
│ ReActAgent2 │
|
||||
│ (ai_core.py) │
|
||||
└──────┬───────┘
|
||||
│
|
||||
┌────────────┼────────────┐
|
||||
│ │
|
||||
┌──────▼──────┐ ┌──────▼───────┐
|
||||
│ Vector │ │ GraphRAG │
|
||||
│ Query Tool │ │ Query Tool │
|
||||
│ │ │ │
|
||||
│ LlamaIndex │ │ Vector + │
|
||||
│ VectorStore │ │ Community │
|
||||
│ Index │ │ Retrieval │
|
||||
└─────────────┘ └──────┬───────┘
|
||||
│
|
||||
┌──────▼───────┐
|
||||
│ Neo4j │
|
||||
│ Knowledge │
|
||||
│ Graph │
|
||||
└──────────────┘
|
||||
```
|
||||
|
||||
## API Reference
|
||||
|
||||
| Method | Endpoint | Description |
|
||||
|--------|----------|-------------|
|
||||
| `POST` | `/chat` | Send a chat message. Body: `{message, sessionId}` |
|
||||
| `GET` | `/status?sessionId=` | Check system initialization status |
|
||||
| `GET` | `/conversations` | List user's conversations (requires `X-MS-USERNAME` header) |
|
||||
| `POST` | `/conversations/new` | Create a new conversation |
|
||||
| `GET` | `/conversations/:id/messages` | Get messages for a conversation |
|
||||
| `DELETE` | `/conversations/:id` | Soft-delete a conversation |
|
||||
| `POST` | `/reset` | Reset global agent memory. Body: `{sessionId}` |
|
||||
| `GET` | `/images/:filename` | Serve a document page image |
|
||||
| `GET` | `/list-images` | List all available images |
|
||||
| `POST` | `/download-brief` | Generate Word doc. Body: `{brief_content, sessionId}` |
|
||||
|
||||
**Authentication**: The frontend sends the MSAL username via `X-MS-USERNAME` header. In development mode (`PRODUCTION=false`), a default `dev_user@local` is used.
|
||||
|
||||
## Configuration
|
||||
|
||||
All configuration is centralized in `config.py`. Key settings:
|
||||
|
||||
| Setting | Default | Description |
|
||||
|---------|---------|-------------|
|
||||
| `LLM_MODEL` | `chatgpt-4o-latest` | Main LLM for the ReAct agent |
|
||||
| `EMBEDDING_MODEL` | `text-embedding-3-small` | Embedding model for vector index |
|
||||
| `LLM_TEMPERATURE` | `0.3` | LLM temperature |
|
||||
| `SIMILARITY_TOP_K` | `10` | Number of vector results to retrieve |
|
||||
| `AGENT_TIMEOUT` | `600s` | Overall agent workflow timeout |
|
||||
| `LLM_TIMEOUT` | `300s` | Per-LLM-call timeout |
|
||||
| `SERVER_PORT` | `8746` | Backend server port |
|
||||
|
||||
Community summaries use `gpt-4o-mini` for cost efficiency (configured in `graph_rag_integration.py`).
|
||||
|
||||
## Adding Documents
|
||||
|
||||
Place HP marketing documents (PDF, DOCX, PPTX, TXT) in `supporting_files/files_for_rag_store/`. On the next startup with no existing index, the system will:
|
||||
|
||||
1. Parse documents with LlamaParse (text + image extraction)
|
||||
2. Split into semantic chunks
|
||||
3. Build a vector index (persisted to `index_storage/hp_docs_index/`)
|
||||
4. Extract knowledge graph triplets and store in Neo4j
|
||||
5. Run community detection and cache summaries
|
||||
|
||||
To force a full reindex, delete `index_storage/hp_docs_index/` and clear the Neo4j database before restarting.
|
||||
|
||||
## Deployment
|
||||
|
||||
### Backend
|
||||
- Set `PRODUCTION=true` environment variable
|
||||
- Server binds to `0.0.0.0` in production mode
|
||||
- Configure `CORS_ALLOWED_ORIGINS` in `config.py`
|
||||
- Production URL: `https://ai-sandbox.oliver.solutions/hp_chatbot_back`
|
||||
|
||||
### Frontend
|
||||
```bash
|
||||
cd chat-interface
|
||||
npm run build
|
||||
```
|
||||
- Deploy `dist/` contents to the `/hp_chatbot/` path
|
||||
- Ensure proper MIME types for `.js` files on the web server
|
||||
- Configure SPA routing (see `web.config` or `.htaccess`)
|
||||
- Production URL: `https://ai-sandbox.oliver.solutions/hp_chatbot/`
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
| Issue | Solution |
|
||||
|-------|----------|
|
||||
| Backend won't start | Check that Neo4j and MongoDB are running. Verify `OPENAI_API_KEY` is set in `.env` |
|
||||
| "Agent unavailable" errors | Check startup logs for LLM API test failure. The `/reinitialize` endpoint (dev only) can force re-init |
|
||||
| No images in responses | Verify `LLAMA_CLOUD_API_KEY` is set. Check that `uploads/images/` contains extracted images |
|
||||
| CORS errors | Add the frontend origin to `CORS_ALLOWED_ORIGINS` in `config.py` |
|
||||
| Slow first startup | Initial document processing and graph building can take significant time depending on document volume |
|
||||
| Neo4j connection refused | Ensure Neo4j is on port 7688 (not 7687, which is a different project) |
|
||||
1374
ai_core.py
Normal file
1374
ai_core.py
Normal file
File diff suppressed because it is too large
Load diff
24
chat-interface/.gitignore
vendored
Normal file
24
chat-interface/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
# Logs
|
||||
logs
|
||||
*.log
|
||||
npm-debug.log*
|
||||
yarn-debug.log*
|
||||
yarn-error.log*
|
||||
pnpm-debug.log*
|
||||
lerna-debug.log*
|
||||
|
||||
node_modules
|
||||
dist
|
||||
dist-ssr
|
||||
*.local
|
||||
|
||||
# Editor directories and files
|
||||
.vscode/*
|
||||
!.vscode/extensions.json
|
||||
.idea
|
||||
.DS_Store
|
||||
*.suo
|
||||
*.ntvs*
|
||||
*.njsproj
|
||||
*.sln
|
||||
*.sw?
|
||||
37
chat-interface/.htaccess
Normal file
37
chat-interface/.htaccess
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
# JavaScript MIME type
|
||||
AddType application/javascript .js
|
||||
AddType application/json .json
|
||||
|
||||
# CSS MIME type
|
||||
AddType text/css .css
|
||||
|
||||
# Image MIME types
|
||||
AddType image/svg+xml .svg
|
||||
AddType image/png .png
|
||||
AddType image/jpeg .jpg
|
||||
AddType image/jpeg .jpeg
|
||||
AddType image/gif .gif
|
||||
AddType image/webp .webp
|
||||
|
||||
# Font MIME types
|
||||
AddType font/ttf .ttf
|
||||
AddType font/otf .otf
|
||||
AddType font/woff .woff
|
||||
AddType font/woff2 .woff2
|
||||
|
||||
# Force JavaScript MIME type for all JS files
|
||||
<Files *.js>
|
||||
ForceType application/javascript
|
||||
</Files>
|
||||
|
||||
# Enable mod_rewrite
|
||||
RewriteEngine On
|
||||
RewriteBase /netflix_chatbot/
|
||||
|
||||
# Don't rewrite files or directories
|
||||
RewriteCond %{REQUEST_FILENAME} -f [OR]
|
||||
RewriteCond %{REQUEST_FILENAME} -d
|
||||
RewriteRule ^ - [L]
|
||||
|
||||
# Rewrite everything else to index.html
|
||||
RewriteRule ^ index.html [L]
|
||||
50
chat-interface/DEPLOY.md
Normal file
50
chat-interface/DEPLOY.md
Normal file
|
|
@ -0,0 +1,50 @@
|
|||
# Deployment Instructions
|
||||
|
||||
## Files to Deploy
|
||||
Deploy all files from the `dist` directory to your server at the path `/hp_chatbot/`.
|
||||
|
||||
## Changing the Backend URL
|
||||
If you need to change the backend API URL:
|
||||
|
||||
1. Edit the `.env` and `.env.production` files to update the `VITE_BACKEND_URL` value
|
||||
2. Or use the provided script: `./update-backend.sh` which will update the URL and rebuild the application
|
||||
3. Then rebuild the application with `npm run build`
|
||||
|
||||
## Important Server Configuration
|
||||
The application requires proper MIME type configuration to work correctly. Depending on your server type, use one of the following:
|
||||
|
||||
### Apache Server
|
||||
Make sure the `.htaccess` file is included in your deployment. It contains:
|
||||
- MIME type configurations
|
||||
- URL rewrite rules for SPA routing
|
||||
- CORS headers
|
||||
|
||||
### IIS Server
|
||||
Make sure the `web.config` file is included in your deployment. It contains:
|
||||
- MIME type configurations
|
||||
- URL rewrite rules
|
||||
- CORS headers
|
||||
|
||||
## Common Issues and Solutions
|
||||
|
||||
### JavaScript MIME Type Error
|
||||
If you get an error like:
|
||||
```
|
||||
Loading module from "https://ai-sandbox.oliver.solutions/hp_chatbot/assets/index-XXXXX.js" was blocked because of a disallowed MIME type ("text/html").
|
||||
```
|
||||
|
||||
Check that:
|
||||
1. Your server is serving .js files with the correct MIME type: `application/javascript`
|
||||
2. The `.htaccess` or `web.config` file is properly uploaded and enabled
|
||||
3. Your server allows URL rewriting and custom MIME types
|
||||
|
||||
### Authentication Issues
|
||||
- Make sure the Microsoft authentication is set up correctly
|
||||
- The redirectUri in the MSAL configuration should match your deployment URL
|
||||
- Check that cookies and localStorage are enabled in the browser
|
||||
|
||||
### API Connection Issues
|
||||
- The app expects the backend API to be accessible at the configured URL (default: `https://ai-sandbox.oliver.solutions/hp_chatbot_back`)
|
||||
- If you need to change the backend API URL, follow the instructions in "Changing the Backend URL" section above
|
||||
- In development, API calls are proxied through Vite's development server
|
||||
- In production, API calls go directly to the configured backend URL
|
||||
72
chat-interface/README.md
Normal file
72
chat-interface/README.md
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
# HP Marketing Materials Chatbot
|
||||
|
||||
A React frontend for the HP Marketing Materials Chatbot, providing a chat interface to query the HP marketing knowledge base.
|
||||
|
||||
## Features
|
||||
|
||||
- Clean chat interface for asking questions about HP marketing materials
|
||||
- Sources and reasoning display for transparency
|
||||
- Session-based memory for contextual conversations
|
||||
- Conversation management system
|
||||
|
||||
## Development
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Node.js 18+
|
||||
- npm or yarn
|
||||
|
||||
### Setup
|
||||
|
||||
1. Clone the repository
|
||||
2. Install dependencies:
|
||||
```bash
|
||||
npm install
|
||||
```
|
||||
|
||||
3. Start the development server:
|
||||
```bash
|
||||
npm run dev
|
||||
```
|
||||
|
||||
### Configuration
|
||||
|
||||
The application uses environment variables for configuration. Create a `.env` file in the root directory with the following variables:
|
||||
|
||||
```
|
||||
# Backend API URL
|
||||
VITE_BACKEND_URL=https://ai-sandbox.oliver.solutions/hp_chatbot_back
|
||||
|
||||
# Base URL for the app (changes in production)
|
||||
VITE_APP_BASE_URL=/
|
||||
```
|
||||
|
||||
### Changing the Backend URL
|
||||
|
||||
If you need to change the backend API URL:
|
||||
|
||||
1. Edit the `.env` and `.env.production` files to update the `VITE_BACKEND_URL` value
|
||||
2. Or use the provided script: `./update-backend.sh` which will update the URL and rebuild the application
|
||||
3. Then rebuild the application with `npm run build`
|
||||
|
||||
## Building for Production
|
||||
|
||||
To create a production build:
|
||||
|
||||
```bash
|
||||
npm run build
|
||||
```
|
||||
|
||||
The output will be in the `dist` directory, ready for deployment.
|
||||
|
||||
## Deployment
|
||||
|
||||
See [DEPLOY.md](./DEPLOY.md) for detailed deployment instructions.
|
||||
|
||||
## Technologies Used
|
||||
|
||||
- React 18
|
||||
- Vite
|
||||
- TailwindCSS
|
||||
- Microsoft Authentication Library (MSAL) or custom authentication
|
||||
- Shadcn/ui components
|
||||
17
chat-interface/components.json
Normal file
17
chat-interface/components.json
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"style": "default",
|
||||
"rsc": false,
|
||||
"tsx": false,
|
||||
"tailwind": {
|
||||
"config": "tailwind.config.js",
|
||||
"css": "src/index.css",
|
||||
"baseColor": "slate",
|
||||
"cssVariables": true
|
||||
},
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils"
|
||||
}
|
||||
}
|
||||
|
||||
38
chat-interface/eslint.config.js
Normal file
38
chat-interface/eslint.config.js
Normal file
|
|
@ -0,0 +1,38 @@
|
|||
import js from '@eslint/js'
|
||||
import globals from 'globals'
|
||||
import react from 'eslint-plugin-react'
|
||||
import reactHooks from 'eslint-plugin-react-hooks'
|
||||
import reactRefresh from 'eslint-plugin-react-refresh'
|
||||
|
||||
export default [
|
||||
{ ignores: ['dist'] },
|
||||
{
|
||||
files: ['**/*.{js,jsx}'],
|
||||
languageOptions: {
|
||||
ecmaVersion: 2020,
|
||||
globals: globals.browser,
|
||||
parserOptions: {
|
||||
ecmaVersion: 'latest',
|
||||
ecmaFeatures: { jsx: true },
|
||||
sourceType: 'module',
|
||||
},
|
||||
},
|
||||
settings: { react: { version: '18.3' } },
|
||||
plugins: {
|
||||
react,
|
||||
'react-hooks': reactHooks,
|
||||
'react-refresh': reactRefresh,
|
||||
},
|
||||
rules: {
|
||||
...js.configs.recommended.rules,
|
||||
...react.configs.recommended.rules,
|
||||
...react.configs['jsx-runtime'].rules,
|
||||
...reactHooks.configs.recommended.rules,
|
||||
'react/jsx-no-target-blank': 'off',
|
||||
'react-refresh/only-export-components': [
|
||||
'warn',
|
||||
{ allowConstantExport: true },
|
||||
],
|
||||
},
|
||||
},
|
||||
]
|
||||
31
chat-interface/index.html
Normal file
31
chat-interface/index.html
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<!-- Helps with MIME type issues -->
|
||||
<meta http-equiv="X-Content-Type-Options" content="nosniff" />
|
||||
<title>HP Marketing Materials Chatbot</title>
|
||||
<!-- MSAL Authentication -->
|
||||
<script src="https://alcdn.msauth.net/browser/2.15.0/js/msal-browser.min.js" crossorigin="anonymous"></script>
|
||||
<style>
|
||||
#protected-content {
|
||||
display: none;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="login-container" class="flex items-center justify-center h-screen">
|
||||
<div class="text-center">
|
||||
<h1 class="text-2xl font-bold mb-4">HP Marketing Materials Chatbot</h1>
|
||||
<p class="mb-4">Please sign in to access the chatbot.</p>
|
||||
<button id="signin-button" class="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600">
|
||||
Sign in with Microsoft
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div id="root" style="display: none;"></div>
|
||||
<script type="module" src="/src/main.jsx"></script>
|
||||
</body>
|
||||
</html>
|
||||
6784
chat-interface/package-lock.json
generated
Normal file
6784
chat-interface/package-lock.json
generated
Normal file
File diff suppressed because it is too large
Load diff
40
chat-interface/package.json
Normal file
40
chat-interface/package.json
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
{
|
||||
"name": "hp-chat-interface",
|
||||
"private": true,
|
||||
"version": "0.0.0",
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"dev": "vite",
|
||||
"build": "vite build",
|
||||
"lint": "eslint .",
|
||||
"preview": "vite preview"
|
||||
},
|
||||
"dependencies": {
|
||||
"@radix-ui/react-slot": "^1.1.0",
|
||||
"@radix-ui/react-tooltip": "^1.1.4",
|
||||
"react": "^18.3.1",
|
||||
"react-dom": "^18.3.1",
|
||||
"showdown": "^2.1.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@eslint/js": "^9.13.0",
|
||||
"@shadcn/ui": "^0.0.4",
|
||||
"@types/react": "^18.3.12",
|
||||
"@types/react-dom": "^18.3.1",
|
||||
"@vitejs/plugin-react": "^4.3.3",
|
||||
"autoprefixer": "^10.4.20",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"clsx": "^2.1.1",
|
||||
"eslint": "^9.13.0",
|
||||
"eslint-plugin-react": "^7.37.2",
|
||||
"eslint-plugin-react-hooks": "^5.0.0",
|
||||
"eslint-plugin-react-refresh": "^0.4.14",
|
||||
"globals": "^15.11.0",
|
||||
"lucide-react": "^0.455.0",
|
||||
"postcss": "^8.4.47",
|
||||
"tailwind-merge": "^2.5.4",
|
||||
"tailwindcss": "^3.4.14",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"vite": "^5.4.10"
|
||||
}
|
||||
}
|
||||
6
chat-interface/postcss.config.js
Normal file
6
chat-interface/postcss.config.js
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
export default {
|
||||
plugins: {
|
||||
tailwindcss: {},
|
||||
autoprefixer: {},
|
||||
},
|
||||
}
|
||||
BIN
chat-interface/public/images/Netflix_2015_N_logo.png
Normal file
BIN
chat-interface/public/images/Netflix_2015_N_logo.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 86 KiB |
3
chat-interface/public/images/netflix-logo.svg
Normal file
3
chat-interface/public/images/netflix-logo.svg
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1024 276.742">
|
||||
<path d="M140.803 258.904c-15.404 2.705-31.079 3.516-47.294 5.676l-49.458-144.856v151.073c-15.404 1.621-29.457 3.783-44.051 5.945v-276.742h41.08l56.212 157.021v-157.021h43.511v258.904zm85.131-157.558c16.757 0 42.431-.811 57.835-.811v43.24c-19.189 0-41.619 0-57.835.811v64.322c25.405-1.621 50.809-3.785 76.482-4.596v41.617l-119.724 9.461v-255.39h119.724v43.241h-76.482v58.105zm237.284-58.104h-44.862v198.908c-14.594 0-29.188 0-43.239.539v-199.447h-44.862v-43.242h132.965l-.002 43.242zm70.266 55.132h59.187v43.24h-59.187v98.104h-42.433v-239.718h120.808v43.241h-78.375v55.133zm148.641 103.507c24.594.539 49.456 2.434 73.51 3.783v42.701c-38.646-2.434-77.293-4.863-116.75-5.676v-242.689h43.24v201.881zm109.994 49.457c13.783.812 28.377 1.623 42.43 3.242v-254.58h-42.43v251.338zm231.881-251.338l-54.863 131.615 54.863 145.127c-16.217-2.162-32.432-5.135-48.648-7.838l-31.078-79.994-31.617 73.51c-15.678-2.705-30.812-3.516-46.484-5.678l55.672-126.75-50.269-129.992h46.482l28.377 74.59 30.27-74.59h47.295z" fill="#d81f26"/>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 1.1 KiB |
1
chat-interface/public/vite.svg
Normal file
1
chat-interface/public/vite.svg
Normal file
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="31.88" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 257"><defs><linearGradient id="IconifyId1813088fe1fbc01fb466" x1="-.828%" x2="57.636%" y1="7.652%" y2="78.411%"><stop offset="0%" stop-color="#41D1FF"></stop><stop offset="100%" stop-color="#BD34FE"></stop></linearGradient><linearGradient id="IconifyId1813088fe1fbc01fb467" x1="43.376%" x2="50.316%" y1="2.242%" y2="89.03%"><stop offset="0%" stop-color="#FFEA83"></stop><stop offset="8.333%" stop-color="#FFDD35"></stop><stop offset="100%" stop-color="#FFA800"></stop></linearGradient></defs><path fill="url(#IconifyId1813088fe1fbc01fb466)" d="M255.153 37.938L134.897 252.976c-2.483 4.44-8.862 4.466-11.382.048L.875 37.958c-2.746-4.814 1.371-10.646 6.827-9.67l120.385 21.517a6.537 6.537 0 0 0 2.322-.004l117.867-21.483c5.438-.991 9.574 4.796 6.877 9.62Z"></path><path fill="url(#IconifyId1813088fe1fbc01fb467)" d="M185.432.063L96.44 17.501a3.268 3.268 0 0 0-2.634 3.014l-5.474 92.456a3.268 3.268 0 0 0 3.997 3.378l24.777-5.718c2.318-.535 4.413 1.507 3.936 3.838l-7.361 36.047c-.495 2.426 1.782 4.5 4.151 3.78l15.304-4.649c2.372-.72 4.652 1.36 4.15 3.788l-11.698 56.621c-.732 3.542 3.979 5.473 5.943 2.437l1.313-2.028l72.516-144.72c1.215-2.423-.88-5.186-3.54-4.672l-25.505 4.922c-2.396.462-4.435-1.77-3.759-4.114l16.646-57.705c.677-2.35-1.37-4.583-3.769-4.113Z"></path></svg>
|
||||
|
After Width: | Height: | Size: 1.5 KiB |
42
chat-interface/src/App.css
Normal file
42
chat-interface/src/App.css
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
#root {
|
||||
max-width: 1280px;
|
||||
margin: 0 auto;
|
||||
padding: 2rem;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.logo {
|
||||
height: 6em;
|
||||
padding: 1.5em;
|
||||
will-change: filter;
|
||||
transition: filter 300ms;
|
||||
}
|
||||
.logo:hover {
|
||||
filter: drop-shadow(0 0 2em #646cffaa);
|
||||
}
|
||||
.logo.react:hover {
|
||||
filter: drop-shadow(0 0 2em #61dafbaa);
|
||||
}
|
||||
|
||||
@keyframes logo-spin {
|
||||
from {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: no-preference) {
|
||||
a:nth-of-type(2) .logo {
|
||||
animation: logo-spin infinite 20s linear;
|
||||
}
|
||||
}
|
||||
|
||||
.card {
|
||||
padding: 2em;
|
||||
}
|
||||
|
||||
.read-the-docs {
|
||||
color: #888;
|
||||
}
|
||||
1525
chat-interface/src/App.jsx
Normal file
1525
chat-interface/src/App.jsx
Normal file
File diff suppressed because it is too large
Load diff
691
chat-interface/src/App_11-19.jsx
Normal file
691
chat-interface/src/App_11-19.jsx
Normal file
|
|
@ -0,0 +1,691 @@
|
|||
import { useState, useRef, useEffect } from 'react';
|
||||
import { Send, Upload, Loader2, X, FileText, Info, FileDown } from 'lucide-react';
|
||||
import { Alert, AlertDescription } from "./components/ui/alert";
|
||||
import * as Tooltip from '@radix-ui/react-tooltip';
|
||||
|
||||
const BACKEND_URL = 'http://localhost:6173';
|
||||
|
||||
const fetchDefaults = {
|
||||
mode: 'cors',
|
||||
credentials: 'include',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'Accept': '*/*',
|
||||
'Access-Control-Allow-Origin': '*',
|
||||
}
|
||||
};
|
||||
|
||||
export default function ChatInterface() {
|
||||
const [messages, setMessages] = useState([]);
|
||||
const [inputMessage, setInputMessage] = useState('');
|
||||
const [isProcessing, setIsProcessing] = useState(false);
|
||||
const [briefFiles, setBriefFiles] = useState([]);
|
||||
const [supportingFiles, setSupportingFiles] = useState([]);
|
||||
const [isInitialized, setIsInitialized] = useState(false);
|
||||
const [error, setError] = useState(null);
|
||||
|
||||
const messagesEndRef = useRef(null);
|
||||
const fileInputBriefRef = useRef(null);
|
||||
const fileInputSupportingRef = useRef(null);
|
||||
|
||||
useEffect(() => {
|
||||
// Enable CORS for all fetch requests
|
||||
fetch.defaults = fetchDefaults;
|
||||
}, []);
|
||||
|
||||
const logToConsole = (type, message, data = null) => {
|
||||
const timestamp = new Date().toISOString();
|
||||
const logMessage = {
|
||||
timestamp,
|
||||
type,
|
||||
message,
|
||||
data
|
||||
};
|
||||
console.log(JSON.stringify(logMessage, null, 2));
|
||||
};
|
||||
|
||||
const scrollToBottom = () => {
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: "smooth" });
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
}, [messages]);
|
||||
|
||||
const handleFilesChange = (e, type) => {
|
||||
const files = Array.from(e.target.files);
|
||||
logToConsole('info', `Handling ${type} files change`, {
|
||||
fileCount: files.length,
|
||||
fileDetails: files.map(f => ({
|
||||
name: f.name,
|
||||
type: f.type,
|
||||
size: f.size
|
||||
}))
|
||||
});
|
||||
|
||||
if (type === 'brief') {
|
||||
setBriefFiles(prev => [...prev, ...files]);
|
||||
} else {
|
||||
setSupportingFiles(prev => [...prev, ...files]);
|
||||
}
|
||||
};
|
||||
|
||||
const removeFile = (fileName, type) => {
|
||||
logToConsole('info', `Removing file`, {
|
||||
fileName,
|
||||
type
|
||||
});
|
||||
|
||||
if (type === 'brief') {
|
||||
setBriefFiles(prev => prev.filter(file => file.name !== fileName));
|
||||
} else {
|
||||
setSupportingFiles(prev => prev.filter(file => file.name !== fileName));
|
||||
}
|
||||
};
|
||||
|
||||
const MessageBubble = ({ message }) => {
|
||||
// Helper function to safely extract source content
|
||||
const formatSources = (sources) => {
|
||||
if (!sources || !Array.isArray(sources)) return [];
|
||||
|
||||
return sources.map((source) => {
|
||||
// Handle the nested structure from the backend
|
||||
if (source && typeof source === 'object') {
|
||||
// Check for the different possible source structures
|
||||
if (source.content) {
|
||||
return {
|
||||
text: typeof source.content === 'string'
|
||||
? source.content
|
||||
: JSON.stringify(source.content),
|
||||
tool: source.tool_name || ''
|
||||
};
|
||||
}
|
||||
// Fallback for other source structures
|
||||
return {
|
||||
text: JSON.stringify(source),
|
||||
tool: ''
|
||||
};
|
||||
}
|
||||
// Handle string sources
|
||||
return {
|
||||
text: String(source),
|
||||
tool: ''
|
||||
};
|
||||
}).filter(source => source.text); // Remove empty sources
|
||||
};
|
||||
|
||||
// Helper function to format reasoning steps
|
||||
const formatReasoning = (reasoning) => {
|
||||
if (!reasoning || !Array.isArray(reasoning)) return [];
|
||||
|
||||
return reasoning.map((step, index) => {
|
||||
// Handle the new reasoning step structure
|
||||
const stepType = step.type || 'thought';
|
||||
let content = '';
|
||||
|
||||
if (step.action) {
|
||||
content = `${step.action}${step.action_input ? `: ${step.action_input}` : ''}`;
|
||||
} else if (step.observation) {
|
||||
content = step.observation;
|
||||
} else if (step.response) {
|
||||
content = step.response;
|
||||
} else if (step.thought) {
|
||||
content = step.thought;
|
||||
}
|
||||
|
||||
return {
|
||||
id: index,
|
||||
type: stepType,
|
||||
content: content,
|
||||
thought: step.thought
|
||||
};
|
||||
}).filter(step => step.content);
|
||||
};
|
||||
|
||||
const sources = formatSources(message.sources);
|
||||
const reasoningSteps = formatReasoning(message.reasoning);
|
||||
|
||||
const handleDownloadBrief = async () => {
|
||||
// First, send the message to generate the brief
|
||||
const briefRequestMessage = "write a comprehensive, detailed, and organized marketing brief based on the entire history of this conversation";
|
||||
setMessages(prev => [...prev, { role: 'user', content: briefRequestMessage }]);
|
||||
setIsProcessing(true);
|
||||
|
||||
try {
|
||||
// First request to generate the brief content
|
||||
const response = await fetch(`${BACKEND_URL}/chat`, {
|
||||
...fetchDefaults,
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ message: briefRequestMessage })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to generate brief');
|
||||
}
|
||||
|
||||
const responseData = await response.json();
|
||||
|
||||
// Add the assistant's response to the chat
|
||||
setMessages(prev => [...prev, {
|
||||
role: 'assistant',
|
||||
content: responseData.data.response,
|
||||
sources: responseData.data.sources,
|
||||
reasoning: responseData.data.reasoning
|
||||
}]);
|
||||
|
||||
// Then make a second request to download the document
|
||||
const downloadResponse = await fetch(`${BACKEND_URL}/download-brief`, {
|
||||
...fetchDefaults,
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ brief_content: responseData.data.response })
|
||||
});
|
||||
|
||||
if (!downloadResponse.ok) {
|
||||
throw new Error('Failed to download brief');
|
||||
}
|
||||
|
||||
// Get the blob from the response
|
||||
const blob = await downloadResponse.blob();
|
||||
|
||||
// Create a download link and trigger it
|
||||
const downloadUrl = window.URL.createObjectURL(blob);
|
||||
const link = document.createElement('a');
|
||||
link.href = downloadUrl;
|
||||
link.download = 'marketing_brief.docx';
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
window.URL.revokeObjectURL(downloadUrl);
|
||||
|
||||
} catch (e) {
|
||||
console.error('Error in download brief:', e);
|
||||
setError('Failed to download brief. Please try again.');
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
return (
|
||||
<Tooltip.Provider>
|
||||
<div className="flex flex-col h-screen max-w-4xl mx-auto p-4">
|
||||
{!isInitialized ? (
|
||||
// ... (initialization UI remains the same)
|
||||
) : (
|
||||
<>
|
||||
<div className="flex-1 overflow-y-auto space-y-4 mb-4">
|
||||
{messages.map((message, index) => (
|
||||
<MessageBubble key={index} message={message} />
|
||||
))}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
|
||||
<div className="flex space-x-2">
|
||||
<input
|
||||
type="text"
|
||||
value={inputMessage}
|
||||
onChange={(e) => setInputMessage(e.target.value)}
|
||||
onKeyPress={(e) => e.key === 'Enter' && handleSubmit()}
|
||||
placeholder="Type your message..."
|
||||
className="flex-1 p-2 border rounded focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
disabled={isProcessing}
|
||||
/>
|
||||
<button
|
||||
onClick={handleSubmit}
|
||||
disabled={isProcessing || !inputMessage.trim()}
|
||||
className="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isProcessing ? (
|
||||
<Loader2 className="animate-spin" size={16} />
|
||||
) : (
|
||||
<Send size={16} />
|
||||
)}
|
||||
</button>
|
||||
<button
|
||||
onClick={handleDownloadBrief}
|
||||
disabled={isProcessing}
|
||||
className="px-4 py-2 bg-green-500 text-white rounded hover:bg-green-600 disabled:bg-gray-300 disabled:cursor-not-allowed flex items-center space-x-2"
|
||||
>
|
||||
{isProcessing ? (
|
||||
<Loader2 className="animate-spin" size={16} />
|
||||
) : (
|
||||
<>
|
||||
<FileDown size={16} />
|
||||
<span>Download Brief</span>
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<Alert variant="destructive" className="mt-4">
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
</div>
|
||||
</Tooltip.Provider>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
// return (
|
||||
// <div className={`flex ${message.role === 'user' ? 'justify-end' : 'justify-start'} mb-4`}>
|
||||
// <div className={`max-w-[80%] rounded-lg p-3 ${
|
||||
// message.role === 'user' ? 'bg-blue-500 text-white' : 'bg-gray-100'
|
||||
// }`}>
|
||||
// <div className="mb-2 whitespace-pre-wrap">{message.content}</div>
|
||||
|
||||
// <div className="flex gap-2 mt-2">
|
||||
// {sources.length > 0 && (
|
||||
// <Tooltip.Provider delayDuration={200}>
|
||||
// <Tooltip.Root>
|
||||
// <Tooltip.Trigger asChild>
|
||||
// <button
|
||||
// className={`flex items-center gap-1 text-xs rounded px-2 py-1 ${
|
||||
// message.role === 'user'
|
||||
// ? 'bg-white/10 hover:bg-white/20'
|
||||
// : 'bg-gray-200 hover:bg-gray-300'
|
||||
// }`}
|
||||
// >
|
||||
// <Info size={12} />
|
||||
// <span>Sources ({sources.length})</span>
|
||||
// </button>
|
||||
// </Tooltip.Trigger>
|
||||
// <Tooltip.Portal>
|
||||
// <Tooltip.Content
|
||||
// className="bg-white p-4 rounded-lg shadow-lg border border-gray-200 max-w-md z-50"
|
||||
// sideOffset={5}
|
||||
// >
|
||||
// <div className="max-h-[300px] overflow-y-auto">
|
||||
// <h4 className="font-semibold mb-2 text-gray-900">Sources Used:</h4>
|
||||
// <ul className="space-y-3">
|
||||
// {sources.map((source, idx) => (
|
||||
// <li key={idx} className="text-sm">
|
||||
// {source.tool && (
|
||||
// <div className="text-xs font-medium text-gray-500 mb-1">
|
||||
// Tool: {source.tool}
|
||||
// </div>
|
||||
// )}
|
||||
// <div className="text-gray-700">{source.text}</div>
|
||||
// </li>
|
||||
// ))}
|
||||
// </ul>
|
||||
// </div>
|
||||
// <Tooltip.Arrow className="fill-white" />
|
||||
// </Tooltip.Content>
|
||||
// </Tooltip.Portal>
|
||||
// </Tooltip.Root>
|
||||
// </Tooltip.Provider>
|
||||
// )}
|
||||
|
||||
// {reasoningSteps.length > 0 && (
|
||||
// <Tooltip.Provider delayDuration={200}>
|
||||
// <Tooltip.Root>
|
||||
// <Tooltip.Trigger asChild>
|
||||
// <button
|
||||
// className={`flex items-center gap-1 text-xs rounded px-2 py-1 ${
|
||||
// message.role === 'user'
|
||||
// ? 'bg-white/10 hover:bg-white/20'
|
||||
// : 'bg-gray-200 hover:bg-gray-300'
|
||||
// }`}
|
||||
// >
|
||||
// <Info size={12} />
|
||||
// <span>Reasoning ({reasoningSteps.length})</span>
|
||||
// </button>
|
||||
// </Tooltip.Trigger>
|
||||
// <Tooltip.Portal>
|
||||
// <Tooltip.Content
|
||||
// className="bg-white p-4 rounded-lg shadow-lg border border-gray-200 max-w-md z-50"
|
||||
// sideOffset={5}
|
||||
// >
|
||||
// <div className="max-h-[300px] overflow-y-auto">
|
||||
// <h4 className="font-semibold mb-2 text-gray-900">Reasoning Steps:</h4>
|
||||
// <ul className="space-y-3">
|
||||
// {reasoningSteps.map((step) => (
|
||||
// <li key={step.id} className="text-sm">
|
||||
// <div className="font-medium text-gray-900 capitalize">
|
||||
// {step.type}:
|
||||
// </div>
|
||||
// <div className="text-gray-700 ml-2">{step.content}</div>
|
||||
// {step.thought && step.thought !== step.content && (
|
||||
// <div className="text-gray-500 ml-2 mt-1 text-xs">
|
||||
// Thought: {step.thought}
|
||||
// </div>
|
||||
// )}
|
||||
// </li>
|
||||
// ))}
|
||||
// </ul>
|
||||
// </div>
|
||||
// <Tooltip.Arrow className="fill-white" />
|
||||
// </Tooltip.Content>
|
||||
// </Tooltip.Portal>
|
||||
// </Tooltip.Root>
|
||||
// </Tooltip.Provider>
|
||||
// )}
|
||||
// </div>
|
||||
// </div>
|
||||
// </div>
|
||||
// );
|
||||
// };
|
||||
|
||||
|
||||
|
||||
const handleSubmit = async () => {
|
||||
if (!inputMessage.trim()) return;
|
||||
|
||||
const currentMessage = inputMessage;
|
||||
setMessages(prev => [...prev, { role: 'user', content: currentMessage }]);
|
||||
setInputMessage('');
|
||||
setIsProcessing(true);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${BACKEND_URL}/chat`, {
|
||||
...fetchDefaults,
|
||||
method: 'POST',
|
||||
body: JSON.stringify({ message: currentMessage })
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to get response');
|
||||
}
|
||||
|
||||
const responseData = await response.json();
|
||||
|
||||
// Log the full response data for debugging
|
||||
console.log('Full response data:', responseData);
|
||||
|
||||
// Extract the relevant data based on the response structure
|
||||
let assistantMessage;
|
||||
if (responseData.status === 'success' && responseData.data) {
|
||||
// New response structure
|
||||
assistantMessage = {
|
||||
role: 'assistant',
|
||||
content: responseData.data.response || '',
|
||||
sources: responseData.data.sources || [],
|
||||
reasoning: responseData.data.reasoning || []
|
||||
};
|
||||
} else if (responseData.result) {
|
||||
// Alternative response structure
|
||||
assistantMessage = {
|
||||
role: 'assistant',
|
||||
content: responseData.result.response || '',
|
||||
sources: responseData.result.sources || [],
|
||||
reasoning: responseData.result.reasoning || []
|
||||
};
|
||||
} else {
|
||||
// Fallback for direct response structure
|
||||
assistantMessage = {
|
||||
role: 'assistant',
|
||||
content: responseData.response || '',
|
||||
sources: responseData.sources || [],
|
||||
reasoning: responseData.reasoning || []
|
||||
};
|
||||
}
|
||||
|
||||
// Log the processed message for debugging
|
||||
console.log('Processed assistant message:', assistantMessage);
|
||||
|
||||
setMessages(prev => [...prev, assistantMessage]);
|
||||
setError(null);
|
||||
|
||||
} catch (e) {
|
||||
console.error('Error in chat:', e);
|
||||
setError('Failed to process response. Please try again.');
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
const initializeChat = async () => {
|
||||
if (briefFiles.length === 0) {
|
||||
setError('Please upload at least one brief file.');
|
||||
return;
|
||||
}
|
||||
|
||||
logToConsole('info', 'Starting chat initialization', {
|
||||
briefFilesCount: briefFiles.length,
|
||||
supportingFilesCount: supportingFiles.length,
|
||||
briefFiles: briefFiles.map(f => ({
|
||||
name: f.name,
|
||||
type: f.type,
|
||||
size: f.size
|
||||
})),
|
||||
supportingFiles: supportingFiles.map(f => ({
|
||||
name: f.name,
|
||||
type: f.type,
|
||||
size: f.size
|
||||
}))
|
||||
});
|
||||
|
||||
setIsProcessing(true);
|
||||
const formData = new FormData();
|
||||
|
||||
briefFiles.forEach(file => {
|
||||
formData.append('brief', file);
|
||||
logToConsole('debug', 'Appending brief file to FormData', {
|
||||
fileName: file.name,
|
||||
fileType: file.type,
|
||||
fileSize: file.size
|
||||
});
|
||||
});
|
||||
|
||||
supportingFiles.forEach(file => {
|
||||
formData.append('supporting', file);
|
||||
logToConsole('debug', 'Appending supporting file to FormData', {
|
||||
fileName: file.name,
|
||||
fileType: file.type,
|
||||
fileSize: file.size
|
||||
});
|
||||
});
|
||||
|
||||
try {
|
||||
logToConsole('info', 'Sending initialization request to backend');
|
||||
|
||||
const response = await fetch(`${BACKEND_URL}/initialize`, {
|
||||
method: 'POST',
|
||||
mode: 'cors',
|
||||
credentials: 'include',
|
||||
body: formData,
|
||||
// headers: {
|
||||
// 'Accept': '*/*',
|
||||
// 'Access-Control-Allow-Origin': '*',
|
||||
// }
|
||||
});
|
||||
|
||||
logToConsole('debug', 'Received response from backend', {
|
||||
status: response.status,
|
||||
statusText: response.statusText,
|
||||
headers: Object.fromEntries(response.headers.entries())
|
||||
});
|
||||
|
||||
const textResponse = await response.text();
|
||||
logToConsole('debug', 'Response text received', { textResponse });
|
||||
|
||||
let data;
|
||||
try {
|
||||
data = JSON.parse(textResponse);
|
||||
logToConsole('debug', 'Successfully parsed response JSON', { data });
|
||||
} catch (e) {
|
||||
logToConsole('error', 'Failed to parse response JSON', {
|
||||
error: e.message,
|
||||
textResponse
|
||||
});
|
||||
throw new Error('Invalid response format from server');
|
||||
}
|
||||
|
||||
if (!response.ok) {
|
||||
logToConsole('error', 'Backend returned error response', {
|
||||
status: response.status,
|
||||
error: data.error
|
||||
});
|
||||
throw new Error(data.error || 'Failed to initialize chat');
|
||||
}
|
||||
|
||||
setIsInitialized(true);
|
||||
setMessages([{
|
||||
role: 'assistant',
|
||||
content: 'Chat initialized! How can I help you?'
|
||||
}]);
|
||||
|
||||
setError(null);
|
||||
logToConsole('info', 'Chat initialization completed successfully');
|
||||
|
||||
} catch (err) {
|
||||
logToConsole('error', 'Chat initialization failed', {
|
||||
error: err.message,
|
||||
stack: err.stack
|
||||
});
|
||||
|
||||
let errorMessage = 'Failed to initialize chat. ';
|
||||
if (err.message) {
|
||||
errorMessage += err.message;
|
||||
} else {
|
||||
errorMessage += 'Please try again.';
|
||||
}
|
||||
|
||||
setError(errorMessage);
|
||||
setIsInitialized(false);
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Tooltip.Provider>
|
||||
<div className="flex flex-col h-screen max-w-4xl mx-auto p-4">
|
||||
{!isInitialized ? (
|
||||
<div className="space-y-4">
|
||||
<h1 className="text-2xl font-bold mb-4">Upload Files to Start Chat</h1>
|
||||
|
||||
<div className="space-y-2">
|
||||
<h2 className="text-lg font-semibold">Brief Files (Required)</h2>
|
||||
<div className="flex items-center space-x-2">
|
||||
<input
|
||||
type="file"
|
||||
ref={fileInputBriefRef}
|
||||
onChange={(e) => handleFilesChange(e, 'brief')}
|
||||
multiple
|
||||
accept=".pdf,.doc,.docx,.txt,.xls,.xlsx,.ppt,.pptx,.eml"
|
||||
className="hidden"
|
||||
/>
|
||||
<button
|
||||
onClick={() => fileInputBriefRef.current.click()}
|
||||
className="flex items-center space-x-2 px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600"
|
||||
>
|
||||
<Upload size={16} />
|
||||
<span>Upload Brief Files</span>
|
||||
</button>
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{briefFiles.map((file) => (
|
||||
<div key={file.name} className="flex items-center space-x-2 bg-gray-100 p-2 rounded">
|
||||
<FileText size={16} />
|
||||
<span className="truncate flex-1">{file.name}</span>
|
||||
<button
|
||||
onClick={() => removeFile(file.name, 'brief')}
|
||||
className="text-red-500 hover:text-red-700"
|
||||
>
|
||||
<X size={16} />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<h2 className="text-lg font-semibold">Supporting Files (Optional)</h2>
|
||||
<div className="flex items-center space-x-2">
|
||||
<input
|
||||
type="file"
|
||||
ref={fileInputSupportingRef}
|
||||
onChange={(e) => handleFilesChange(e, 'supporting')}
|
||||
multiple
|
||||
accept=".pdf,.doc,.docx,.txt,.xls,.xlsx,.ppt,.pptx,.eml"
|
||||
className="hidden"
|
||||
/>
|
||||
<button
|
||||
onClick={() => fileInputSupportingRef.current.click()}
|
||||
className="flex items-center space-x-2 px-4 py-2 bg-gray-500 text-white rounded hover:bg-gray-600"
|
||||
>
|
||||
<Upload size={16} />
|
||||
<span>Upload Supporting Files</span>
|
||||
</button>
|
||||
</div>
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{supportingFiles.map((file) => (
|
||||
<div key={file.name} className="flex items-center space-x-2 bg-gray-100 p-2 rounded">
|
||||
<FileText size={16} />
|
||||
<span className="truncate flex-1">{file.name}</span>
|
||||
<button
|
||||
onClick={() => removeFile(file.name, 'supporting')}
|
||||
className="text-red-500 hover:text-red-700"
|
||||
>
|
||||
<X size={16} />
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
onClick={initializeChat}
|
||||
disabled={isProcessing || briefFiles.length === 0}
|
||||
className="w-full px-4 py-2 bg-green-500 text-white rounded hover:bg-green-600 disabled:bg-gray-300 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isProcessing ? (
|
||||
<div className="flex items-center justify-center space-x-2">
|
||||
<Loader2 className="animate-spin" size={16} />
|
||||
<span>Processing...</span>
|
||||
</div>
|
||||
) : (
|
||||
'Initialize Chat'
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<div className="flex-1 overflow-y-auto space-y-4 mb-4">
|
||||
{messages.map((message, index) => (
|
||||
<MessageBubble key={index} message={message} />
|
||||
))}
|
||||
<div ref={messagesEndRef} />
|
||||
</div>
|
||||
|
||||
<div className="flex space-x-2">
|
||||
<input
|
||||
type="text"
|
||||
value={inputMessage}
|
||||
onChange={(e) => setInputMessage(e.target.value)}
|
||||
onKeyPress={(e) => e.key === 'Enter' && handleSubmit()}
|
||||
placeholder="Type your message..."
|
||||
className="flex-1 p-2 border rounded focus:outline-none focus:ring-2 focus:ring-blue-500"
|
||||
disabled={isProcessing}
|
||||
/>
|
||||
<button
|
||||
onClick={handleSubmit}
|
||||
disabled={isProcessing || !inputMessage.trim()}
|
||||
className="px-4 py-2 bg-blue-500 text-white rounded hover:bg-blue-600 disabled:bg-gray-300 disabled:cursor-not-allowed"
|
||||
>
|
||||
{isProcessing ? (
|
||||
<Loader2 className="animate-spin" size={16} />
|
||||
) : (
|
||||
<Send size={16} />
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
{error && (
|
||||
<Alert variant="destructive" className="mt-4">
|
||||
<AlertDescription>{error}</AlertDescription>
|
||||
</Alert>
|
||||
)}
|
||||
</div>
|
||||
</Tooltip.Provider>
|
||||
);
|
||||
}
|
||||
1
chat-interface/src/assets/react.svg
Normal file
1
chat-interface/src/assets/react.svg
Normal file
|
|
@ -0,0 +1 @@
|
|||
<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" aria-hidden="true" role="img" class="iconify iconify--logos" width="35.93" height="32" preserveAspectRatio="xMidYMid meet" viewBox="0 0 256 228"><path fill="#00D8FF" d="M210.483 73.824a171.49 171.49 0 0 0-8.24-2.597c.465-1.9.893-3.777 1.273-5.621c6.238-30.281 2.16-54.676-11.769-62.708c-13.355-7.7-35.196.329-57.254 19.526a171.23 171.23 0 0 0-6.375 5.848a155.866 155.866 0 0 0-4.241-3.917C100.759 3.829 77.587-4.822 63.673 3.233C50.33 10.957 46.379 33.89 51.995 62.588a170.974 170.974 0 0 0 1.892 8.48c-3.28.932-6.445 1.924-9.474 2.98C17.309 83.498 0 98.307 0 113.668c0 15.865 18.582 31.778 46.812 41.427a145.52 145.52 0 0 0 6.921 2.165a167.467 167.467 0 0 0-2.01 9.138c-5.354 28.2-1.173 50.591 12.134 58.266c13.744 7.926 36.812-.22 59.273-19.855a145.567 145.567 0 0 0 5.342-4.923a168.064 168.064 0 0 0 6.92 6.314c21.758 18.722 43.246 26.282 56.54 18.586c13.731-7.949 18.194-32.003 12.4-61.268a145.016 145.016 0 0 0-1.535-6.842c1.62-.48 3.21-.974 4.76-1.488c29.348-9.723 48.443-25.443 48.443-41.52c0-15.417-17.868-30.326-45.517-39.844Zm-6.365 70.984c-1.4.463-2.836.91-4.3 1.345c-3.24-10.257-7.612-21.163-12.963-32.432c5.106-11 9.31-21.767 12.459-31.957c2.619.758 5.16 1.557 7.61 2.4c23.69 8.156 38.14 20.213 38.14 29.504c0 9.896-15.606 22.743-40.946 31.14Zm-10.514 20.834c2.562 12.94 2.927 24.64 1.23 33.787c-1.524 8.219-4.59 13.698-8.382 15.893c-8.067 4.67-25.32-1.4-43.927-17.412a156.726 156.726 0 0 1-6.437-5.87c7.214-7.889 14.423-17.06 21.459-27.246c12.376-1.098 24.068-2.894 34.671-5.345a134.17 134.17 0 0 1 1.386 6.193ZM87.276 214.515c-7.882 2.783-14.16 2.863-17.955.675c-8.075-4.657-11.432-22.636-6.853-46.752a156.923 156.923 0 0 1 1.869-8.499c10.486 2.32 22.093 3.988 34.498 4.994c7.084 9.967 14.501 19.128 21.976 27.15a134.668 134.668 0 0 1-4.877 4.492c-9.933 8.682-19.886 14.842-28.658 17.94ZM50.35 144.747c-12.483-4.267-22.792-9.812-29.858-15.863c-6.35-5.437-9.555-10.836-9.555-15.216c0-9.322 13.897-21.212 37.076-29.293c2.813-.98 5.757-1.905 8.812-2.773c3.204 10.42 7.406 21.315 12.477 32.332c-5.137 11.18-9.399 22.249-12.634 32.792a134.718 134.718 0 0 1-6.318-1.979Zm12.378-84.26c-4.811-24.587-1.616-43.134 6.425-47.789c8.564-4.958 27.502 2.111 47.463 19.835a144.318 144.318 0 0 1 3.841 3.545c-7.438 7.987-14.787 17.08-21.808 26.988c-12.04 1.116-23.565 2.908-34.161 5.309a160.342 160.342 0 0 1-1.76-7.887Zm110.427 27.268a347.8 347.8 0 0 0-7.785-12.803c8.168 1.033 15.994 2.404 23.343 4.08c-2.206 7.072-4.956 14.465-8.193 22.045a381.151 381.151 0 0 0-7.365-13.322Zm-45.032-43.861c5.044 5.465 10.096 11.566 15.065 18.186a322.04 322.04 0 0 0-30.257-.006c4.974-6.559 10.069-12.652 15.192-18.18ZM82.802 87.83a323.167 323.167 0 0 0-7.227 13.238c-3.184-7.553-5.909-14.98-8.134-22.152c7.304-1.634 15.093-2.97 23.209-3.984a321.524 321.524 0 0 0-7.848 12.897Zm8.081 65.352c-8.385-.936-16.291-2.203-23.593-3.793c2.26-7.3 5.045-14.885 8.298-22.6a321.187 321.187 0 0 0 7.257 13.246c2.594 4.48 5.28 8.868 8.038 13.147Zm37.542 31.03c-5.184-5.592-10.354-11.779-15.403-18.433c4.902.192 9.899.29 14.978.29c5.218 0 10.376-.117 15.453-.343c-4.985 6.774-10.018 12.97-15.028 18.486Zm52.198-57.817c3.422 7.8 6.306 15.345 8.596 22.52c-7.422 1.694-15.436 3.058-23.88 4.071a382.417 382.417 0 0 0 7.859-13.026a347.403 347.403 0 0 0 7.425-13.565Zm-16.898 8.101a358.557 358.557 0 0 1-12.281 19.815a329.4 329.4 0 0 1-23.444.823c-7.967 0-15.716-.248-23.178-.732a310.202 310.202 0 0 1-12.513-19.846h.001a307.41 307.41 0 0 1-10.923-20.627a310.278 310.278 0 0 1 10.89-20.637l-.001.001a307.318 307.318 0 0 1 12.413-19.761c7.613-.576 15.42-.876 23.31-.876H128c7.926 0 15.743.303 23.354.883a329.357 329.357 0 0 1 12.335 19.695a358.489 358.489 0 0 1 11.036 20.54a329.472 329.472 0 0 1-11 20.722Zm22.56-122.124c8.572 4.944 11.906 24.881 6.52 51.026c-.344 1.668-.73 3.367-1.15 5.09c-10.622-2.452-22.155-4.275-34.23-5.408c-7.034-10.017-14.323-19.124-21.64-27.008a160.789 160.789 0 0 1 5.888-5.4c18.9-16.447 36.564-22.941 44.612-18.3ZM128 90.808c12.625 0 22.86 10.235 22.86 22.86s-10.235 22.86-22.86 22.86s-22.86-10.235-22.86-22.86s10.235-22.86 22.86-22.86Z"></path></svg>
|
||||
|
After Width: | Height: | Size: 4 KiB |
119
chat-interface/src/auth.js
Normal file
119
chat-interface/src/auth.js
Normal file
|
|
@ -0,0 +1,119 @@
|
|||
// MSAL Authentication Configuration
|
||||
const msalConfig = {
|
||||
auth: {
|
||||
clientId: "9079054c-9620-4757-a256-23413042f1ef",
|
||||
authority: "https://login.microsoftonline.com/e519c2e6-bc6d-4fdf-8d9c-923c2f002385",
|
||||
redirectUri: "https://ai-sandbox.oliver.solutions/format"
|
||||
},
|
||||
cache: {
|
||||
cacheLocation: "sessionStorage",
|
||||
storeAuthStateInCookie: true,
|
||||
}
|
||||
};
|
||||
|
||||
const loginRequest = {
|
||||
scopes: ["user.read"]
|
||||
};
|
||||
|
||||
// Initialize MSAL object
|
||||
let myMSALObj = null;
|
||||
let thisUser = null;
|
||||
|
||||
// Initialize MSAL when loaded
|
||||
export function initializeMSAL() {
|
||||
if (typeof msal !== 'undefined') {
|
||||
myMSALObj = new msal.PublicClientApplication(msalConfig);
|
||||
|
||||
// Check if there's a cached user
|
||||
const accounts = myMSALObj.getAllAccounts();
|
||||
if (accounts.length > 0) {
|
||||
thisUser = accounts[0].username;
|
||||
// User is already logged in, show content
|
||||
showProtectedContent();
|
||||
} else {
|
||||
// Need to initialize login elements if not logged in
|
||||
const signinButton = document.getElementById('signin-button');
|
||||
if (signinButton) {
|
||||
signinButton.addEventListener('click', signIn);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
console.error("MSAL library not loaded");
|
||||
}
|
||||
}
|
||||
|
||||
// Sign in with popup
|
||||
export function signIn() {
|
||||
if (!myMSALObj) {
|
||||
initializeMSAL();
|
||||
}
|
||||
|
||||
myMSALObj.loginPopup(loginRequest)
|
||||
.then(loginResponse => {
|
||||
console.log("User logged in:", loginResponse.account.username);
|
||||
thisUser = loginResponse.account.username;
|
||||
sessionStorage.setItem('msalUsername', loginResponse.account.username);
|
||||
showProtectedContent();
|
||||
}).catch(error => {
|
||||
console.error("Error during login:", error);
|
||||
});
|
||||
}
|
||||
|
||||
// Show the protected content (the chat interface)
|
||||
function showProtectedContent() {
|
||||
const loginContainer = document.getElementById('login-container');
|
||||
const rootContainer = document.getElementById('root');
|
||||
|
||||
if (loginContainer) {
|
||||
loginContainer.style.display = 'none';
|
||||
}
|
||||
|
||||
if (rootContainer) {
|
||||
rootContainer.style.display = 'block';
|
||||
}
|
||||
|
||||
// Dispatch an event that authentication is complete
|
||||
window.dispatchEvent(new CustomEvent('authenticationComplete', {
|
||||
detail: {
|
||||
username: thisUser
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
// Get current user
|
||||
export function getCurrentUser() {
|
||||
if (!thisUser) {
|
||||
thisUser = sessionStorage.getItem('msalUsername');
|
||||
}
|
||||
return thisUser;
|
||||
}
|
||||
|
||||
// Check if user is authenticated
|
||||
export function isAuthenticated() {
|
||||
if (!myMSALObj) {
|
||||
initializeMSAL();
|
||||
}
|
||||
|
||||
return myMSALObj?.getAllAccounts().length > 0 || !!getCurrentUser();
|
||||
}
|
||||
|
||||
// Sign out
|
||||
export function signOut() {
|
||||
if (!myMSALObj) {
|
||||
initializeMSAL();
|
||||
}
|
||||
|
||||
const logoutRequest = {
|
||||
account: myMSALObj.getAccountByUsername(thisUser)
|
||||
};
|
||||
|
||||
myMSALObj.logout(logoutRequest).then(() => {
|
||||
thisUser = null;
|
||||
sessionStorage.removeItem('msalUsername');
|
||||
// Redirect to login page
|
||||
window.location.reload();
|
||||
});
|
||||
}
|
||||
|
||||
// Register initialization on window load
|
||||
window.addEventListener('DOMContentLoaded', initializeMSAL);
|
||||
140
chat-interface/src/components/ChatInterface.jsx
Normal file
140
chat-interface/src/components/ChatInterface.jsx
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
// Use environment variables for backend URL
|
||||
// Define backend URL dynamically based on environment
|
||||
const BACKEND_URL = import.meta.env.VITE_BACKEND_URL || 'https://ai-sandbox.oliver.solutions/hp_back_v2';
|
||||
console.log('ChatInterface using backend URL:', BACKEND_URL);
|
||||
|
||||
const initializeChat = async ({ sessionId, briefFiles, supportingFiles, setError, setIsProcessing, setIsInitialized, setMessages }) => {
|
||||
if (briefFiles.length === 0) {
|
||||
setError('Please upload at least one brief file.');
|
||||
return;
|
||||
}
|
||||
|
||||
setIsProcessing(true);
|
||||
const formData = new FormData();
|
||||
|
||||
// Add session ID to the form data
|
||||
formData.append('sessionId', sessionId);
|
||||
|
||||
briefFiles.forEach(file => {
|
||||
formData.append('brief', file);
|
||||
});
|
||||
|
||||
supportingFiles.forEach(file => {
|
||||
formData.append('supporting', file);
|
||||
});
|
||||
|
||||
try {
|
||||
console.log('Sending initialize request...');
|
||||
const response = await fetch(`${BACKEND_URL}/initialize`, {
|
||||
method: 'POST',
|
||||
body: formData, // No Content-Type header for FormData
|
||||
credentials: 'include', // Include credentials (cookies, auth headers)
|
||||
});
|
||||
|
||||
console.log('Response status:', response.status);
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
throw new Error(errorData.error || 'Failed to initialize chat');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log('Response data:', data);
|
||||
|
||||
setIsInitialized(true);
|
||||
setMessages([{
|
||||
role: 'assistant',
|
||||
content: 'Chat initialized! How can I help you?'
|
||||
}]);
|
||||
setError(null); // Clear any previous errors
|
||||
} catch (err) {
|
||||
console.error('Error:', err);
|
||||
setError(err.message || 'Failed to initialize chat. Please try again.');
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSubmit = async ({ inputMessage, sessionId, setMessages, setInputMessage, setIsProcessing, setError, logToConsole }) => {
|
||||
if (!inputMessage.trim()) return;
|
||||
|
||||
const currentMessage = inputMessage;
|
||||
setMessages(prev => [...prev, { role: 'user', content: currentMessage }]);
|
||||
setInputMessage('');
|
||||
setIsProcessing(true);
|
||||
|
||||
try {
|
||||
const response = await fetch(`${BACKEND_URL}/chat`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
body: JSON.stringify({
|
||||
message: currentMessage,
|
||||
sessionId: sessionId
|
||||
}),
|
||||
credentials: 'include', // Include credentials for cross-origin
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json();
|
||||
throw new Error(errorData.error || 'Failed to get response');
|
||||
}
|
||||
|
||||
const dataString = await response.text(); // Get the response as text
|
||||
const data = JSON.parse(dataString); // Parse the JSON string
|
||||
|
||||
// Extract response data, checking for different response formats
|
||||
let responseText = '';
|
||||
let sources = [];
|
||||
let reasoning = [];
|
||||
let images = [];
|
||||
|
||||
if (data.data) {
|
||||
// New response format
|
||||
responseText = data.data.response || '';
|
||||
sources = data.data.sources || [];
|
||||
reasoning = data.data.reasoning || [];
|
||||
images = data.data.images || [];
|
||||
} else {
|
||||
// Direct response format
|
||||
responseText = data.response || '';
|
||||
sources = data.sources || [];
|
||||
reasoning = data.reasoning || [];
|
||||
images = data.images || [];
|
||||
}
|
||||
|
||||
// Log the images
|
||||
if (images && images.length > 0) {
|
||||
logToConsole('info', 'Images received in response', { imageCount: images.length, images });
|
||||
}
|
||||
|
||||
setMessages(prev => {
|
||||
const newMessages = [...prev];
|
||||
newMessages.push({
|
||||
role: 'assistant',
|
||||
content: responseText,
|
||||
sources: sources,
|
||||
reasoning: reasoning,
|
||||
images: images
|
||||
});
|
||||
return newMessages;
|
||||
});
|
||||
setError(null); // Clear any previous error
|
||||
|
||||
} catch (e) { // Include the original JSON string in the error data
|
||||
console.error('Error processing chat response:', e);
|
||||
logToConsole('error', 'Chat error', {
|
||||
error: e.message,
|
||||
stack: e.stack,
|
||||
response: e.response
|
||||
});
|
||||
|
||||
setError('Failed to process response. Please check the logs for details.'); // Display a more user friendly error in UI
|
||||
} finally {
|
||||
setIsProcessing(false);
|
||||
}
|
||||
};
|
||||
|
||||
// Export the functions
|
||||
export { initializeChat, handleSubmit };
|
||||
206
chat-interface/src/components/ConversationManager.jsx
Normal file
206
chat-interface/src/components/ConversationManager.jsx
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
// Get the backend URL from environment or use correct path as fallback
|
||||
// Define backend URL dynamically based on environment
|
||||
const BACKEND_URL = import.meta.env.VITE_BACKEND_URL || 'https://ai-sandbox.oliver.solutions/hp_back_v2';
|
||||
console.log('ConversationManager using backend URL:', BACKEND_URL);
|
||||
|
||||
// Function to load user conversations
|
||||
export const loadUserConversations = async (username, setConversations, setActiveConversation, loadConversation, setError) => {
|
||||
try {
|
||||
const response = await fetch(`${BACKEND_URL}/conversations`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-MS-USERNAME': username || '' // Add authenticated username to request
|
||||
},
|
||||
credentials: 'include'
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to load conversations');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.status === 'success') {
|
||||
setConversations(data.conversations || []);
|
||||
|
||||
// If we have conversations and none are active, select the most recent one
|
||||
if (data.conversations && data.conversations.length > 0 && setActiveConversation) {
|
||||
// Sort by last_updated, newest first
|
||||
const sortedConversations = [...data.conversations].sort((a, b) =>
|
||||
new Date(b.last_updated) - new Date(a.last_updated)
|
||||
);
|
||||
|
||||
// Set the most recent conversation as active
|
||||
setActiveConversation(sortedConversations[0]);
|
||||
|
||||
// Load this conversation if load function provided
|
||||
if (loadConversation) {
|
||||
await loadConversation(sortedConversations[0]);
|
||||
}
|
||||
}
|
||||
|
||||
return data.conversations || [];
|
||||
} else {
|
||||
throw new Error(data.error || 'Failed to load conversations');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error loading conversations:', error);
|
||||
if (setError) {
|
||||
setError('Failed to load conversations. Please try again later.');
|
||||
}
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
// Function to load a specific conversation's messages
|
||||
export const loadConversationMessages = async (conversation, username, setMessages, setSessionId, setError) => {
|
||||
try {
|
||||
if (!conversation || !conversation.id) {
|
||||
throw new Error('Invalid conversation');
|
||||
}
|
||||
|
||||
const response = await fetch(`${BACKEND_URL}/conversations/${conversation.id}/messages`, {
|
||||
method: 'GET',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-MS-USERNAME': username || '' // Add authenticated username to request
|
||||
},
|
||||
credentials: 'include'
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to load conversation messages');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.status === 'success') {
|
||||
// Format the messages for the UI
|
||||
const formattedMessages = data.messages.map(msg => ({
|
||||
role: msg.role,
|
||||
content: msg.content,
|
||||
sources: msg.sources || [],
|
||||
reasoning: msg.reasoning || [],
|
||||
images: msg.images || []
|
||||
}));
|
||||
|
||||
setMessages(formattedMessages);
|
||||
|
||||
// Set the session ID if provided
|
||||
if (setSessionId && conversation.session_id) {
|
||||
setSessionId(conversation.session_id);
|
||||
localStorage.setItem('chatSessionId', conversation.session_id);
|
||||
}
|
||||
|
||||
return formattedMessages;
|
||||
} else {
|
||||
throw new Error(data.error || 'Failed to load conversation messages');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error loading conversation messages:', error);
|
||||
if (setError) {
|
||||
setError('Failed to load conversation messages. Please try again later.');
|
||||
}
|
||||
return [];
|
||||
}
|
||||
};
|
||||
|
||||
// Function to create a new conversation
|
||||
export const createNewConversation = async (username, setConversations, setActiveConversation, setSessionId, setMessages, setError) => {
|
||||
try {
|
||||
const response = await fetch(`${BACKEND_URL}/conversations/new`, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-MS-USERNAME': username || '' // Add authenticated username to request
|
||||
},
|
||||
credentials: 'include'
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to create new conversation');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.status === 'success') {
|
||||
// Create a new conversation object
|
||||
const newConversation = {
|
||||
id: data.conversation_id,
|
||||
title: "New conversation",
|
||||
created_at: new Date().toISOString(),
|
||||
last_updated: new Date().toISOString(),
|
||||
session_id: data.session_id
|
||||
};
|
||||
|
||||
// Update the conversations list
|
||||
setConversations(prev => [newConversation, ...prev]);
|
||||
|
||||
// Set as active conversation
|
||||
setActiveConversation(newConversation);
|
||||
|
||||
// Set the session ID
|
||||
if (setSessionId) {
|
||||
setSessionId(data.session_id);
|
||||
localStorage.setItem('chatSessionId', data.session_id);
|
||||
}
|
||||
|
||||
// Reset messages with welcome message
|
||||
if (setMessages) {
|
||||
setMessages([{
|
||||
role: 'assistant',
|
||||
content: 'How can I help you today?'
|
||||
}]);
|
||||
}
|
||||
|
||||
return newConversation;
|
||||
} else {
|
||||
throw new Error(data.error || 'Failed to create new conversation');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error creating new conversation:', error);
|
||||
if (setError) {
|
||||
setError('Failed to create new conversation. Please try again later.');
|
||||
}
|
||||
return null;
|
||||
}
|
||||
};
|
||||
|
||||
// Function to delete a conversation
|
||||
export const deleteConversation = async (conversationId, username, setConversations, setError) => {
|
||||
try {
|
||||
if (!conversationId) {
|
||||
throw new Error('Invalid conversation ID');
|
||||
}
|
||||
|
||||
const response = await fetch(`${BACKEND_URL}/conversations/${conversationId}`, {
|
||||
method: 'DELETE',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
'X-MS-USERNAME': username || '' // Add authenticated username to request
|
||||
},
|
||||
credentials: 'include'
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to delete conversation');
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (data.status === 'success') {
|
||||
// Remove the conversation from the list
|
||||
setConversations(prev => prev.filter(convo => convo.id !== conversationId));
|
||||
return true;
|
||||
} else {
|
||||
throw new Error(data.error || 'Failed to delete conversation');
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error deleting conversation:', error);
|
||||
if (setError) {
|
||||
setError('Failed to delete conversation. Please try again later.');
|
||||
}
|
||||
return false;
|
||||
}
|
||||
};
|
||||
21
chat-interface/src/components/ThemeToggle.jsx
Normal file
21
chat-interface/src/components/ThemeToggle.jsx
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
import React from 'react';
|
||||
import { useTheme } from '../main';
|
||||
import { Moon, Sun } from 'lucide-react';
|
||||
|
||||
export default function ThemeToggle() {
|
||||
const { darkMode, setDarkMode } = useTheme();
|
||||
|
||||
return (
|
||||
<button
|
||||
onClick={() => setDarkMode(!darkMode)}
|
||||
className="p-2 rounded-full bg-gray-700/40 hover:bg-gray-700/60 transition-colors duration-200 ease-in-out"
|
||||
aria-label={darkMode ? "Switch to light mode" : "Switch to dark mode"}
|
||||
>
|
||||
{darkMode ? (
|
||||
<Sun size={20} className="text-yellow-400" />
|
||||
) : (
|
||||
<Moon size={20} className="text-gray-200" />
|
||||
)}
|
||||
</button>
|
||||
);
|
||||
}
|
||||
37
chat-interface/src/components/ui/alert.jsx
Normal file
37
chat-interface/src/components/ui/alert.jsx
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import * as React from "react"
|
||||
import { cn } from "../../lib/utils"
|
||||
|
||||
const Alert = React.forwardRef(({ className, variant = "default", children, ...props }, ref) => {
|
||||
const variantClasses = {
|
||||
default: "bg-gray-100 text-gray-900",
|
||||
destructive: "bg-red-100 text-red-900",
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={ref}
|
||||
role="alert"
|
||||
className={cn(
|
||||
"relative w-full rounded-lg border p-4",
|
||||
variantClasses[variant],
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
)
|
||||
})
|
||||
Alert.displayName = "Alert"
|
||||
|
||||
const AlertDescription = React.forwardRef(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn("text-sm [&_p]:leading-relaxed", className)}
|
||||
{...props}
|
||||
/>
|
||||
))
|
||||
AlertDescription.displayName = "AlertDescription"
|
||||
|
||||
export { Alert, AlertDescription }
|
||||
|
||||
131
chat-interface/src/index.css
Normal file
131
chat-interface/src/index.css
Normal file
|
|
@ -0,0 +1,131 @@
|
|||
@tailwind base;
|
||||
@tailwind components;
|
||||
@tailwind utilities;
|
||||
|
||||
@layer base {
|
||||
:root {
|
||||
--background: 0 0% 100%;
|
||||
--foreground: 222.2 84% 4.9%;
|
||||
--muted: 210 40% 96.1%;
|
||||
--muted-foreground: 215.4 16.3% 46.9%;
|
||||
--border: 214.3 31.8% 91.4%;
|
||||
--radius: 0.5rem;
|
||||
}
|
||||
|
||||
.dark {
|
||||
--background: 222.2 84% 4.9%;
|
||||
--foreground: 210 40% 98%;
|
||||
--muted: 217.2 32.6% 17.5%;
|
||||
--muted-foreground: 215 20.2% 65.1%;
|
||||
--border: 217.2 32.6% 17.5%;
|
||||
}
|
||||
|
||||
body {
|
||||
@apply bg-black text-foreground;
|
||||
border: none !important;
|
||||
border-top: none !important;
|
||||
border-bottom: none !important;
|
||||
}
|
||||
}
|
||||
|
||||
/* Netflix theme colors */
|
||||
.netflix-red {
|
||||
color: #E50914;
|
||||
}
|
||||
|
||||
.netflix-bg {
|
||||
background-color: #141414;
|
||||
}
|
||||
|
||||
/* Prevent any border from appearing */
|
||||
.no-borders, .no-borders::before, .no-borders::after {
|
||||
border: none !important;
|
||||
border-top: none !important;
|
||||
border-bottom: none !important;
|
||||
border-left: none !important;
|
||||
border-right: none !important;
|
||||
outline: none !important;
|
||||
}
|
||||
|
||||
/* Markdown content styling */
|
||||
.markdown-content {
|
||||
@apply text-base leading-relaxed;
|
||||
}
|
||||
|
||||
.markdown-content h1 {
|
||||
@apply text-2xl font-bold mt-4 mb-2;
|
||||
}
|
||||
|
||||
.markdown-content h2 {
|
||||
@apply text-xl font-bold mt-3 mb-2;
|
||||
}
|
||||
|
||||
.markdown-content h3 {
|
||||
@apply text-lg font-bold mt-3 mb-1;
|
||||
}
|
||||
|
||||
.markdown-content p {
|
||||
@apply my-2;
|
||||
}
|
||||
|
||||
.markdown-content ul {
|
||||
@apply list-disc pl-5 my-2;
|
||||
}
|
||||
|
||||
.markdown-content ol {
|
||||
@apply list-decimal pl-5 my-2;
|
||||
}
|
||||
|
||||
.markdown-content li {
|
||||
@apply my-1;
|
||||
}
|
||||
|
||||
.markdown-content a {
|
||||
@apply text-blue-600 hover:underline;
|
||||
}
|
||||
|
||||
.markdown-content code {
|
||||
@apply font-mono bg-gray-200 px-1 py-0.5 rounded text-sm;
|
||||
}
|
||||
|
||||
.markdown-content pre {
|
||||
@apply bg-gray-800 text-white p-3 rounded my-3 overflow-auto;
|
||||
}
|
||||
|
||||
.markdown-content pre code {
|
||||
@apply bg-transparent text-white text-sm;
|
||||
}
|
||||
|
||||
.markdown-content blockquote {
|
||||
@apply border-l-4 border-gray-300 pl-4 my-3 italic text-gray-700;
|
||||
}
|
||||
|
||||
.markdown-content table {
|
||||
@apply border-collapse border border-gray-300 my-3 w-full;
|
||||
}
|
||||
|
||||
.markdown-content th {
|
||||
@apply border border-gray-300 bg-gray-100 p-2 font-semibold;
|
||||
}
|
||||
|
||||
.markdown-content td {
|
||||
@apply border border-gray-300 p-2;
|
||||
}
|
||||
|
||||
/* Different background for assistant markdown vs user text */
|
||||
.bg-blue-500 .markdown-content {
|
||||
@apply text-white;
|
||||
}
|
||||
|
||||
.bg-blue-500 .markdown-content code {
|
||||
@apply bg-blue-600 text-white;
|
||||
}
|
||||
|
||||
.bg-blue-500 .markdown-content a {
|
||||
@apply text-white underline;
|
||||
}
|
||||
|
||||
.bg-blue-500 .markdown-content blockquote {
|
||||
@apply border-white/50 text-white/90;
|
||||
}
|
||||
|
||||
40
chat-interface/src/lib/utils.js
Normal file
40
chat-interface/src/lib/utils.js
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
// Original utils.js file content
|
||||
import { clsx } from "clsx"
|
||||
import { twMerge } from "tailwind-merge"
|
||||
|
||||
export function cn(...inputs) {
|
||||
return twMerge(clsx(inputs))
|
||||
}
|
||||
|
||||
/**
|
||||
* Enhanced fetch function with a 300-second timeout
|
||||
* @param {string} url - URL to fetch
|
||||
* @param {Object} options - Fetch options
|
||||
* @returns {Promise} - Fetch promise with timeout
|
||||
*/
|
||||
export async function fetchWithTimeout(url, options = {}) {
|
||||
const timeout = 300000; // 300 seconds in milliseconds
|
||||
|
||||
const controller = new AbortController();
|
||||
const { signal } = controller;
|
||||
|
||||
const timeoutId = setTimeout(() => {
|
||||
controller.abort();
|
||||
}, timeout);
|
||||
|
||||
try {
|
||||
const response = await fetch(url, {
|
||||
...options,
|
||||
signal,
|
||||
});
|
||||
|
||||
clearTimeout(timeoutId);
|
||||
return response;
|
||||
} catch (error) {
|
||||
clearTimeout(timeoutId);
|
||||
if (error.name === 'AbortError') {
|
||||
throw new Error(`Request timed out after ${timeout / 1000} seconds`);
|
||||
}
|
||||
throw error;
|
||||
}
|
||||
}
|
||||
75
chat-interface/src/main.jsx
Normal file
75
chat-interface/src/main.jsx
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
import React, { useEffect, useState, createContext, useContext } from 'react'
|
||||
import ReactDOM from 'react-dom/client'
|
||||
import App from './App'
|
||||
import './index.css'
|
||||
import { initializeMSAL, isAuthenticated } from './auth'
|
||||
|
||||
// Initialize MSAL
|
||||
initializeMSAL();
|
||||
|
||||
// Create Theme Context
|
||||
export const ThemeContext = createContext();
|
||||
|
||||
export const ThemeProvider = ({ children }) => {
|
||||
const [darkMode, setDarkMode] = useState(() => {
|
||||
// Check local storage for user preference
|
||||
const savedMode = localStorage.getItem('darkMode');
|
||||
return savedMode === 'true';
|
||||
});
|
||||
|
||||
// Apply dark mode class to html element
|
||||
useEffect(() => {
|
||||
if (darkMode) {
|
||||
document.documentElement.classList.add('dark');
|
||||
} else {
|
||||
document.documentElement.classList.remove('dark');
|
||||
}
|
||||
// Save preference to localStorage
|
||||
localStorage.setItem('darkMode', darkMode);
|
||||
}, [darkMode]);
|
||||
|
||||
return (
|
||||
<ThemeContext.Provider value={{ darkMode, setDarkMode }}>
|
||||
{children}
|
||||
</ThemeContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
// Custom hook for using theme context
|
||||
export const useTheme = () => useContext(ThemeContext);
|
||||
|
||||
const AuthenticatedApp = () => {
|
||||
const [authenticated, setAuthenticated] = useState(isAuthenticated());
|
||||
|
||||
useEffect(() => {
|
||||
// Listen for authentication complete event
|
||||
const handleAuth = () => {
|
||||
setAuthenticated(true);
|
||||
};
|
||||
|
||||
window.addEventListener('authenticationComplete', handleAuth);
|
||||
|
||||
// Check if already authenticated
|
||||
if (isAuthenticated()) {
|
||||
handleAuth();
|
||||
}
|
||||
|
||||
return () => {
|
||||
window.removeEventListener('authenticationComplete', handleAuth);
|
||||
};
|
||||
}, []);
|
||||
|
||||
// The App component will only be rendered when user is authenticated
|
||||
return authenticated ? (
|
||||
<ThemeProvider>
|
||||
<App />
|
||||
</ThemeProvider>
|
||||
) : null;
|
||||
};
|
||||
|
||||
ReactDOM.createRoot(document.getElementById('root')).render(
|
||||
<React.StrictMode>
|
||||
<AuthenticatedApp />
|
||||
</React.StrictMode>,
|
||||
)
|
||||
|
||||
28
chat-interface/tailwind.config.js
Normal file
28
chat-interface/tailwind.config.js
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
/** @type {import('tailwindcss').Config} */
|
||||
export default {
|
||||
content: [
|
||||
"./index.html",
|
||||
"./src/**/*.{js,jsx,ts,tsx}",
|
||||
],
|
||||
darkMode: 'class', // Enable class-based dark mode
|
||||
theme: {
|
||||
extend: {
|
||||
colors: {
|
||||
border: "hsl(var(--border))",
|
||||
background: "hsl(var(--background))",
|
||||
foreground: "hsl(var(--foreground))",
|
||||
muted: {
|
||||
DEFAULT: "hsl(var(--muted))",
|
||||
foreground: "hsl(var(--muted-foreground))"
|
||||
},
|
||||
},
|
||||
borderRadius: {
|
||||
lg: "var(--radius)",
|
||||
md: "calc(var(--radius) - 2px)",
|
||||
sm: "calc(var(--radius) - 4px)",
|
||||
},
|
||||
},
|
||||
},
|
||||
plugins: [],
|
||||
}
|
||||
|
||||
38
chat-interface/update-backend.sh
Executable file
38
chat-interface/update-backend.sh
Executable file
|
|
@ -0,0 +1,38 @@
|
|||
#!/bin/bash
|
||||
|
||||
# This script updates the backend URL and rebuilds the application
|
||||
|
||||
# Define the backend URL
|
||||
BACKEND_URL="https://ai-sandbox.oliver.solutions/netflix_back_v2"
|
||||
|
||||
# Create or update the environment files
|
||||
echo "# Backend API URL
|
||||
VITE_BACKEND_URL=${BACKEND_URL}
|
||||
|
||||
# Base URL for the app (changes in production)
|
||||
VITE_APP_BASE_URL=/" > .env
|
||||
|
||||
echo "# Production backend API URL
|
||||
VITE_BACKEND_URL=${BACKEND_URL}
|
||||
|
||||
# Base URL for the app in production
|
||||
VITE_APP_BASE_URL=/netflix_v2/" > .env.production
|
||||
|
||||
# Update components files
|
||||
for file in src/components/ChatInterface.jsx src/components/ConversationManager.jsx src/App.jsx; do
|
||||
# Replace hardcoded backend URL references
|
||||
sed -i '' "s|const BACKEND_URL = ['\"]\/netflix_back_v2['\"]|const BACKEND_URL = import.meta.env.VITE_BACKEND_URL || '${BACKEND_URL}'|g" $file
|
||||
sed -i '' "s|console.log('.*using backend URL (hardcoded):|console.log('Using backend URL:|g" $file
|
||||
done
|
||||
|
||||
# Update vite.config.js to use the replacement plugin
|
||||
sed -i '' 's|plugins: \[react()\],|plugins: [react(), replaceBackendUrl],|g' vite.config.js
|
||||
sed -i '' 's|code: src.replace(/\['"'\]\/netflix_back\['"'\]/g, '"\/netflix_back_v2"'),|code: src.replace(/\['"'\]\/netflix_back\['"'\]/g, `"${BACKEND_URL}"`),|g' vite.config.js
|
||||
|
||||
# Force a complete rebuild by removing all assets
|
||||
rm -rf dist/assets/*
|
||||
|
||||
# Rebuild the application
|
||||
npm run build
|
||||
|
||||
echo "Update complete! The backend URL has been set to ${BACKEND_URL}"
|
||||
124
chat-interface/vite.config.js
Normal file
124
chat-interface/vite.config.js
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
import { defineConfig } from 'vite'
|
||||
import react from '@vitejs/plugin-react'
|
||||
import path from 'path'
|
||||
|
||||
export default defineConfig(({ mode }) => {
|
||||
const baseUrl = mode === 'production' ? '/hp_chatbot/' : '/';
|
||||
|
||||
// Create a replace plugin to ensure all '/hp_back' gets replaced with the full URL
|
||||
const replaceBackendUrl = {
|
||||
name: 'replace-backend-url',
|
||||
transform(src, id) {
|
||||
if (id.endsWith('.js') || id.endsWith('.jsx')) {
|
||||
return {
|
||||
code: src.replace(/['"]\/hp_back['"]/g, '"https://ai-sandbox.oliver.solutions/hp_back_v2"'),
|
||||
map: null
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
plugins: [react(), replaceBackendUrl],
|
||||
base: baseUrl,
|
||||
server: {
|
||||
port: 5173,
|
||||
proxy: {
|
||||
// New MongoDB conversation endpoints
|
||||
'/conversations': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/conversations/new': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
|
||||
// Original endpoints
|
||||
'/chat': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/status': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/initialize': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/reset': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/download-brief': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/init-chunked-upload': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/upload-chunk': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/finalize-upload': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/upload-small-file': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/initialize-from-uploads': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/images': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/list-images': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
},
|
||||
'/capture-screenshot': {
|
||||
target: 'https://ai-sandbox.oliver.solutions/hp_back_v2',
|
||||
changeOrigin: true,
|
||||
secure: true,
|
||||
timeout: 300000 // 5 minutes timeout
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
})
|
||||
|
||||
46
chat-interface/web.config
Normal file
46
chat-interface/web.config
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<configuration>
|
||||
<system.webServer>
|
||||
<staticContent>
|
||||
<!-- Set correct MIME types -->
|
||||
<remove fileExtension=".js" />
|
||||
<mimeMap fileExtension=".js" mimeType="application/javascript" />
|
||||
|
||||
<remove fileExtension=".json" />
|
||||
<mimeMap fileExtension=".json" mimeType="application/json" />
|
||||
|
||||
<remove fileExtension=".css" />
|
||||
<mimeMap fileExtension=".css" mimeType="text/css" />
|
||||
|
||||
<remove fileExtension=".svg" />
|
||||
<mimeMap fileExtension=".svg" mimeType="image/svg+xml" />
|
||||
|
||||
<remove fileExtension=".woff" />
|
||||
<mimeMap fileExtension=".woff" mimeType="font/woff" />
|
||||
|
||||
<remove fileExtension=".woff2" />
|
||||
<mimeMap fileExtension=".woff2" mimeType="font/woff2" />
|
||||
</staticContent>
|
||||
|
||||
<!-- Enable CORS -->
|
||||
<httpProtocol>
|
||||
<customHeaders>
|
||||
<add name="Access-Control-Allow-Origin" value="*" />
|
||||
</customHeaders>
|
||||
</httpProtocol>
|
||||
|
||||
<!-- URL Rewrite for SPA -->
|
||||
<rewrite>
|
||||
<rules>
|
||||
<rule name="SPA_Fallback" stopProcessing="true">
|
||||
<match url=".*" />
|
||||
<conditions logicalGrouping="MatchAll">
|
||||
<add input="{REQUEST_FILENAME}" matchType="IsFile" negate="true" />
|
||||
<add input="{REQUEST_FILENAME}" matchType="IsDirectory" negate="true" />
|
||||
</conditions>
|
||||
<action type="Rewrite" url="/netflix_chatbot/index.html" />
|
||||
</rule>
|
||||
</rules>
|
||||
</rewrite>
|
||||
</system.webServer>
|
||||
</configuration>
|
||||
103
config.py
Normal file
103
config.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
# hp_chatbot/config.py
|
||||
import os
|
||||
from pathlib import Path
|
||||
from dotenv import load_dotenv
|
||||
import sys
|
||||
|
||||
# Load environment variables from .env file
|
||||
env_path = Path(__file__).resolve().parent / '.env'
|
||||
if env_path.exists():
|
||||
# Force reload to ensure environment variables are set
|
||||
load_dotenv(dotenv_path=env_path, override=True)
|
||||
print(f"Loaded environment variables from {env_path}")
|
||||
else:
|
||||
print(f"WARNING: .env file not found at {env_path}", file=sys.stderr)
|
||||
|
||||
# --- Directory Paths ---
|
||||
BASE_DIR = Path(__file__).resolve().parent
|
||||
|
||||
UPLOAD_DIR = BASE_DIR / 'uploads'
|
||||
CHUNK_FOLDER = UPLOAD_DIR / 'chunks'
|
||||
UPLOAD_METADATA_FOLDER = UPLOAD_DIR / 'metadata'
|
||||
IMAGES_DIRECTORY = UPLOAD_DIR / 'images'
|
||||
|
||||
SUPPORTING_FILES_DIR = BASE_DIR / 'supporting_files'
|
||||
HP_DOCS_FOLDER = SUPPORTING_FILES_DIR / 'files_for_rag_store'
|
||||
|
||||
INDEX_STORAGE_DIR = BASE_DIR / 'index_storage'
|
||||
INDEX_PERSIST_PATH = INDEX_STORAGE_DIR / "hp_docs_index"
|
||||
|
||||
LOG_FILE_PATH = BASE_DIR / 'app.log'
|
||||
|
||||
# Create necessary directories
|
||||
os.makedirs(CHUNK_FOLDER, exist_ok=True)
|
||||
os.makedirs(UPLOAD_METADATA_FOLDER, exist_ok=True)
|
||||
os.makedirs(INDEX_STORAGE_DIR, exist_ok=True)
|
||||
os.makedirs(IMAGES_DIRECTORY, exist_ok=True)
|
||||
|
||||
# --- Application Settings ---
|
||||
ALLOWED_EXTENSIONS = {'pdf', 'doc', 'docx', 'txt', 'xls', 'xlsx', 'ppt', 'pptx', 'eml'}
|
||||
APPLICATION_ROOT = os.environ.get('APPLICATION_ROOT', '') # For running behind proxy
|
||||
MAX_CONTENT_LENGTH = 500 * 1024 * 1024 # 500MB limit (adjust as needed)
|
||||
|
||||
# --- API Keys ---
|
||||
# Load from environment variables or use defaults (replace placeholders or set env vars)
|
||||
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "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")
|
||||
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY", "")
|
||||
LLAMA_CLOUD_API_KEY = os.environ.get("LLAMA_CLOUD_API_KEY", "")
|
||||
|
||||
# Ensure required keys are set
|
||||
if not OPENAI_API_KEY:
|
||||
print("ERROR: OPENAI_API_KEY not set in environment or .env file. This is required.", file=sys.stderr)
|
||||
print("Please add OPENAI_API_KEY=your_key to your .env file.", file=sys.stderr)
|
||||
print(f"Current environment keys: {list(filter(lambda k: 'key' in k.lower(), os.environ.keys()))}", file=sys.stderr)
|
||||
|
||||
# Always set environment variables, even if empty - the error messages will be handled by the code
|
||||
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
|
||||
os.environ["ANTHROPIC_API_KEY"] = ANTHROPIC_API_KEY
|
||||
os.environ["LLAMA_CLOUD_API_KEY"] = LLAMA_CLOUD_API_KEY
|
||||
|
||||
# Print API key status for debugging
|
||||
print(f"OpenAI API key {'is set' if OPENAI_API_KEY else 'is NOT set'}", file=sys.stderr)
|
||||
|
||||
# --- AI Model Configuration ---
|
||||
LLM_MODEL = "chatgpt-4o-latest" # Or "gpt-4o" etc.
|
||||
EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
LLM_TEMPERATURE = 0.3
|
||||
LLM_TIMEOUT = 300.0 # 5 minutes
|
||||
AGENT_TIMEOUT = 600.0 # 10 minutes for the agent run
|
||||
TOOL_EXECUTION_TIMEOUT = 300.0 # 5 minutes for individual tool calls
|
||||
|
||||
# --- LlamaParse Configuration ---
|
||||
LLAMA_PARSE_VENDOR_MODEL = "openai-gpt4o" # Verify model name
|
||||
LLAMA_PARSE_MAX_TIMEOUT = 3600 # 1 hour
|
||||
|
||||
# --- Indexing Configuration ---
|
||||
# NODE_PARSER_CHUNK_SIZE = 2048 # Example if using SentenceSplitter
|
||||
# NODE_PARSER_CHUNK_OVERLAP = 20
|
||||
# Use Semantic Splitter by default (see ai_core.py)
|
||||
SIMILARITY_TOP_K = 10
|
||||
SIMILARITY_CUTOFF = 0.0 # Adjust if needed
|
||||
|
||||
# --- CORS Configuration ---
|
||||
CORS_ALLOWED_ORIGINS = ["http://localhost:5173", "https://ai-sandbox.oliver.solutions"] # HP chatbot CORS origins
|
||||
CORS_SUPPORTS_CREDENTIALS = True
|
||||
|
||||
# --- Server Configuration ---
|
||||
SERVER_HOST = "0.0.0.0" if os.environ.get("PRODUCTION", "false").lower() == "true" else "localhost"
|
||||
SERVER_PORT = int(os.environ.get("PORT", "8746")) # Port for HP chatbot (unique from Netflix)
|
||||
USE_RELOADER = os.environ.get("PRODUCTION", "false").lower() != "true"
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO") # Changed default to INFO
|
||||
|
||||
# Hypercorn specific timeouts (in seconds)
|
||||
KEEP_ALIVE_TIMEOUT = 300
|
||||
READ_TIMEOUT = 300
|
||||
WRITE_TIMEOUT = 300
|
||||
|
||||
# --- MongoDB Configuration ---
|
||||
# Assumes mongodb_utils handles connection details (e.g., via environment variables)
|
||||
|
||||
# --- Neo4j Configuration ---
|
||||
NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://localhost:7688") # Separate Neo4j instance for HP
|
||||
NEO4J_USERNAME = os.environ.get("NEO4J_USERNAME", "neo4j")
|
||||
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "hp-graphrag-2024") # HP-specific password
|
||||
BIN
docs/graphRAG chatbot technical documentation.pdf
Normal file
BIN
docs/graphRAG chatbot technical documentation.pdf
Normal file
Binary file not shown.
666
docs/graphRAG_chatbot_documentation.md
Normal file
666
docs/graphRAG_chatbot_documentation.md
Normal file
|
|
@ -0,0 +1,666 @@
|
|||
# HP GraphRAG Chatbot - Technical Documentation
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [System Overview](#system-overview)
|
||||
2. [Architecture](#architecture)
|
||||
3. [Technology Stack](#technology-stack)
|
||||
4. [Data Flow](#data-flow)
|
||||
5. [Database Design](#database-design)
|
||||
6. [API Reference](#api-reference)
|
||||
7. [User Flow](#user-flow)
|
||||
8. [Security](#security)
|
||||
9. [Deployment](#deployment)
|
||||
10. [Development](#development)
|
||||
11. [Troubleshooting](#troubleshooting)
|
||||
|
||||
---
|
||||
|
||||
## System Overview
|
||||
|
||||
The HP GraphRAG Chatbot is a sophisticated conversational AI system that combines vector search with knowledge graph capabilities to answer questions about HP marketing materials and brand guidelines. It processes multimodal documents (text + images) and uses a hybrid AI agent approach for intelligent information retrieval and response generation.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Multimodal Document Processing**: Extracts text and images from PDFs, PowerPoint, and other marketing documents
|
||||
- **GraphRAG Architecture**: Combines vector similarity search with knowledge graph community detection
|
||||
- **Custom ReAct Agent**: Implements reasoning and action patterns for intelligent query processing
|
||||
- **Session Management**: Maintains conversation context across multiple interactions
|
||||
- **Image Display**: Shows relevant document screenshots alongside responses
|
||||
- **Authentication**: Microsoft Authentication Library (MSAL) integration
|
||||
- **Conversation History**: Persistent storage and retrieval of chat sessions
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Frontend (React)"
|
||||
FE[Chat Interface]
|
||||
AUTH[MSAL Auth]
|
||||
CONV[Conversation Manager]
|
||||
UI[UI Components]
|
||||
end
|
||||
|
||||
subgraph "Backend (Python/Flask)"
|
||||
API[Flask Routes]
|
||||
AGENT[ReAct Agent]
|
||||
GRAPH[GraphRAG Engine]
|
||||
SESSION[Session Manager]
|
||||
PARSE[Document Parser]
|
||||
end
|
||||
|
||||
subgraph "AI/ML Layer"
|
||||
LLM[OpenAI GPT-4]
|
||||
EMBED[Text Embeddings]
|
||||
LLAMAPARSE[LlamaParse]
|
||||
end
|
||||
|
||||
subgraph "Data Storage"
|
||||
NEO4J[(Neo4j<br/>Knowledge Graph)]
|
||||
MONGO[(MongoDB<br/>Conversations)]
|
||||
VECTOR[(Vector Index<br/>LlamaIndex)]
|
||||
FILES[File Storage<br/>Images/Documents]
|
||||
end
|
||||
|
||||
FE --> API
|
||||
AUTH --> API
|
||||
CONV --> API
|
||||
|
||||
API --> AGENT
|
||||
API --> SESSION
|
||||
API --> PARSE
|
||||
|
||||
AGENT --> GRAPH
|
||||
GRAPH --> NEO4J
|
||||
GRAPH --> VECTOR
|
||||
AGENT --> LLM
|
||||
|
||||
PARSE --> LLAMAPARSE
|
||||
LLAMAPARSE --> FILES
|
||||
PARSE --> EMBED
|
||||
EMBED --> VECTOR
|
||||
|
||||
SESSION --> MONGO
|
||||
|
||||
style FE fill:#e1f5fe
|
||||
style API fill:#f3e5f5
|
||||
style AGENT fill:#e8f5e8
|
||||
style NEO4J fill:#fff3e0
|
||||
style MONGO fill:#f1f8e9
|
||||
```
|
||||
|
||||
### Component Breakdown
|
||||
|
||||
#### Frontend (React)
|
||||
- **Chat Interface**: Main conversational UI with message bubbles, image viewing, and input handling
|
||||
- **Authentication**: MSAL-based Microsoft authentication
|
||||
- **Conversation Manager**: Handles multiple conversation sessions and history
|
||||
- **Theme Toggle**: Dark/light mode support
|
||||
|
||||
#### Backend (Python/Flask)
|
||||
- **Flask Routes**: RESTful API endpoints for chat, authentication, file serving
|
||||
- **ReAct Agent**: Custom implementation with reasoning, action, and observation cycles
|
||||
- **GraphRAG Engine**: Hybrid retrieval combining vector search with graph-based community detection
|
||||
- **Session Manager**: Maps frontend sessions to database conversations
|
||||
- **Document Parser**: LlamaParse integration for multimodal document processing
|
||||
|
||||
#### Data Layer
|
||||
- **Neo4j**: Stores knowledge graph with entities, relationships, and communities
|
||||
- **MongoDB**: Persists user conversations, messages, and session state
|
||||
- **Vector Index**: LlamaIndex-based semantic search capabilities
|
||||
- **File Storage**: Local filesystem for processed images and documents
|
||||
|
||||
---
|
||||
|
||||
## Technology Stack
|
||||
|
||||
### Backend
|
||||
- **Framework**: Flask + Hypercorn (ASGI)
|
||||
- **AI/ML**:
|
||||
- OpenAI GPT-4 (LLM)
|
||||
- text-embedding-3-small (embeddings)
|
||||
- LlamaParse (document processing)
|
||||
- LlamaIndex (vector indexing)
|
||||
- **Databases**:
|
||||
- Neo4j (knowledge graph)
|
||||
- MongoDB (conversations)
|
||||
- **Languages**: Python 3.9+
|
||||
|
||||
### Frontend
|
||||
- **Framework**: React 18 + Vite
|
||||
- **Styling**: TailwindCSS + Shadcn/ui
|
||||
- **Authentication**: Microsoft Authentication Library (MSAL)
|
||||
- **Languages**: JavaScript/JSX
|
||||
|
||||
### Infrastructure
|
||||
- **Web Server**: Hypercorn (ASGI server)
|
||||
- **Containerization**: Docker support
|
||||
- **Deployment**: Azure/Cloud-based
|
||||
|
||||
---
|
||||
|
||||
## Data Flow
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant User
|
||||
participant Frontend
|
||||
participant API
|
||||
participant Agent
|
||||
participant GraphRAG
|
||||
participant Neo4j
|
||||
participant Vector
|
||||
participant OpenAI
|
||||
participant MongoDB
|
||||
|
||||
User->>Frontend: Send message
|
||||
Frontend->>API: POST /chat
|
||||
API->>Agent: Process query
|
||||
|
||||
Agent->>GraphRAG: Retrieve context
|
||||
GraphRAG->>Vector: Vector similarity search
|
||||
GraphRAG->>Neo4j: Community detection
|
||||
GraphRAG->>OpenAI: Generate synthesis
|
||||
GraphRAG->>Agent: Combined context
|
||||
|
||||
Agent->>OpenAI: Generate response
|
||||
OpenAI->>Agent: Response + reasoning
|
||||
|
||||
Agent->>API: Structured response
|
||||
API->>MongoDB: Store conversation
|
||||
API->>Frontend: Response with sources/images
|
||||
Frontend->>User: Display response
|
||||
```
|
||||
|
||||
### Document Processing Flow
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
UPLOAD[Document Upload] --> PARSE[LlamaParse Processing]
|
||||
PARSE --> EXTRACT[Extract Text + Images]
|
||||
EXTRACT --> SPLIT[Semantic Splitting]
|
||||
SPLIT --> EMBED[Generate Embeddings]
|
||||
EMBED --> VECTOR[Store in Vector Index]
|
||||
SPLIT --> GRAPH[Extract Entities/Relations]
|
||||
GRAPH --> NEO4J[Store in Neo4j]
|
||||
EXTRACT --> IMAGES[Save Page Images]
|
||||
IMAGES --> STORAGE[File Storage]
|
||||
NEO4J --> COMMUNITY[Community Detection]
|
||||
COMMUNITY --> CACHE[Cache Communities]
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Database Design
|
||||
|
||||
### Neo4j Knowledge Graph Schema
|
||||
|
||||
```mermaid
|
||||
erDiagram
|
||||
Entity {
|
||||
string name
|
||||
string label
|
||||
string description
|
||||
dict properties
|
||||
}
|
||||
|
||||
Relation {
|
||||
string label
|
||||
string source_id
|
||||
string target_id
|
||||
string description
|
||||
dict properties
|
||||
}
|
||||
|
||||
Community {
|
||||
int community_id
|
||||
text summary
|
||||
list entity_ids
|
||||
}
|
||||
|
||||
Entity ||--o{ Relation : participates_in
|
||||
Community ||--o{ Entity : contains
|
||||
```
|
||||
|
||||
### MongoDB Collections Schema
|
||||
|
||||
```mermaid
|
||||
erDiagram
|
||||
Users {
|
||||
ObjectId _id
|
||||
string username
|
||||
string email
|
||||
datetime created_at
|
||||
datetime last_login
|
||||
}
|
||||
|
||||
Conversations {
|
||||
ObjectId _id
|
||||
string session_id
|
||||
ObjectId user_id
|
||||
string title
|
||||
datetime created_at
|
||||
datetime last_updated
|
||||
boolean is_deleted
|
||||
}
|
||||
|
||||
Messages {
|
||||
ObjectId _id
|
||||
ObjectId conversation_id
|
||||
string role
|
||||
text content
|
||||
array sources
|
||||
array reasoning
|
||||
array images
|
||||
datetime timestamp
|
||||
}
|
||||
|
||||
Users ||--o{ Conversations : owns
|
||||
Conversations ||--o{ Messages : contains
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API Reference
|
||||
|
||||
### Authentication
|
||||
All API endpoints require authentication via `X-MS-USERNAME` header (except in development mode).
|
||||
|
||||
### Core Endpoints
|
||||
|
||||
#### POST /chat
|
||||
Processes chat messages and returns AI responses.
|
||||
|
||||
**Request:**
|
||||
```json
|
||||
{
|
||||
"message": "string",
|
||||
"sessionId": "string"
|
||||
}
|
||||
```
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"data": {
|
||||
"response": "string",
|
||||
"sources": [
|
||||
{
|
||||
"content": "string",
|
||||
"tool_name": "string",
|
||||
"retrieval_method": "vector_only|graphrag_hybrid"
|
||||
}
|
||||
],
|
||||
"reasoning": [
|
||||
{
|
||||
"type": "ActionReasoningStep|ObservationReasoningStep",
|
||||
"action": "string",
|
||||
"observation": "string"
|
||||
}
|
||||
],
|
||||
"images": [
|
||||
{
|
||||
"filename": "string",
|
||||
"document": "string",
|
||||
"page": "number",
|
||||
"url_encoded_filename": "string"
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
#### GET /status
|
||||
Returns system initialization status.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"global_status": "initialized",
|
||||
"initialized": true,
|
||||
"timestamp": "2024-01-01T00:00:00.000Z"
|
||||
}
|
||||
```
|
||||
|
||||
#### GET /conversations
|
||||
Retrieves user's conversation history.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"conversations": [
|
||||
{
|
||||
"id": "string",
|
||||
"title": "string",
|
||||
"created_at": "datetime",
|
||||
"last_updated": "datetime",
|
||||
"session_id": "string"
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### GET /conversations/{id}/messages
|
||||
Retrieves messages for a specific conversation.
|
||||
|
||||
**Response:**
|
||||
```json
|
||||
{
|
||||
"status": "success",
|
||||
"conversation_title": "string",
|
||||
"messages": [
|
||||
{
|
||||
"id": "string",
|
||||
"role": "user|assistant",
|
||||
"content": "string",
|
||||
"timestamp": "datetime",
|
||||
"sources": [],
|
||||
"reasoning": [],
|
||||
"images": []
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
#### GET /images/{filename}
|
||||
Serves processed document images.
|
||||
|
||||
#### POST /reset
|
||||
Resets the global agent's conversation memory.
|
||||
|
||||
#### DELETE /conversations/{id}
|
||||
Deletes a conversation (soft delete by default).
|
||||
|
||||
---
|
||||
|
||||
## User Flow
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
START([User Access]) --> AUTH{Authenticated?}
|
||||
AUTH -->|No| LOGIN[MSAL Login]
|
||||
LOGIN --> AUTH
|
||||
AUTH -->|Yes| LOAD[Load Conversations]
|
||||
|
||||
LOAD --> NEWCHAT[Create New Chat]
|
||||
NEWCHAT --> INTERFACE[Chat Interface]
|
||||
|
||||
INTERFACE --> INPUT[User Input]
|
||||
INPUT --> PROCESS[Process with GraphRAG]
|
||||
PROCESS --> RETRIEVE[Hybrid Retrieval]
|
||||
RETRIEVE --> GENERATE[Generate Response]
|
||||
GENERATE --> DISPLAY[Display with Images]
|
||||
DISPLAY --> INPUT
|
||||
|
||||
DISPLAY --> SAVE[Save to History]
|
||||
SAVE --> UPDATE[Update Conversation]
|
||||
|
||||
INTERFACE --> HISTORY[View History]
|
||||
HISTORY --> SELECT[Select Conversation]
|
||||
SELECT --> LOAD_MSG[Load Messages]
|
||||
LOAD_MSG --> INTERFACE
|
||||
|
||||
INTERFACE --> EXPORT[Export Brief]
|
||||
INTERFACE --> DELETE[Delete Conversation]
|
||||
```
|
||||
|
||||
### Detailed User Journey
|
||||
|
||||
1. **Authentication**: User logs in via Microsoft MSAL
|
||||
2. **Conversation Creation**: System creates new conversation or loads existing
|
||||
3. **Query Processing**:
|
||||
- User sends message
|
||||
- GraphRAG performs hybrid retrieval
|
||||
- Vector similarity search finds relevant chunks
|
||||
- Knowledge graph identifies related communities
|
||||
- LLM synthesizes response with reasoning
|
||||
4. **Response Display**:
|
||||
- Text response with markdown support
|
||||
- Source attribution tooltips
|
||||
- Relevant document images
|
||||
- Reasoning chain (if available)
|
||||
5. **History Management**: Conversations persisted and retrievable
|
||||
|
||||
---
|
||||
|
||||
## Security
|
||||
|
||||
### Authentication
|
||||
- Microsoft Authentication Library (MSAL) integration
|
||||
- Azure AD tenant-based access control
|
||||
- Session-based user identification
|
||||
|
||||
### Data Protection
|
||||
- No sensitive data logged in plain text
|
||||
- Conversation data encrypted at rest (MongoDB)
|
||||
- API key management via environment variables
|
||||
- CORS configuration for cross-origin requests
|
||||
|
||||
### Access Control
|
||||
- User-scoped conversation access
|
||||
- Session-based authorization
|
||||
- Development vs production mode differentiation
|
||||
|
||||
---
|
||||
|
||||
## Deployment
|
||||
|
||||
### Environment Configuration
|
||||
|
||||
#### Backend (.env)
|
||||
```bash
|
||||
# API Keys
|
||||
OPENAI_API_KEY=your_openai_key
|
||||
LLAMA_CLOUD_API_KEY=your_llama_cloud_key
|
||||
ANTHROPIC_API_KEY=your_anthropic_key
|
||||
|
||||
# Database Configuration
|
||||
NEO4J_URL=bolt://localhost:7688
|
||||
NEO4J_USERNAME=neo4j
|
||||
NEO4J_PASSWORD=hp-graphrag-2024
|
||||
|
||||
# Server Configuration
|
||||
PORT=8746
|
||||
PRODUCTION=true
|
||||
LOG_LEVEL=INFO
|
||||
```
|
||||
|
||||
#### Frontend (.env)
|
||||
```bash
|
||||
VITE_BACKEND_URL=https://ai-sandbox.oliver.solutions/hp_chatbot_back
|
||||
VITE_APP_BASE_URL=/hp_chatbot/
|
||||
```
|
||||
|
||||
### Deployment Steps
|
||||
|
||||
1. **Database Setup**:
|
||||
- Neo4j instance on port 7688
|
||||
- MongoDB with authentication
|
||||
- Initialize collections via `init_mongodb.py`
|
||||
|
||||
2. **Backend Deployment**:
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
python main.py
|
||||
```
|
||||
|
||||
3. **Frontend Build**:
|
||||
```bash
|
||||
cd chat-interface
|
||||
npm install
|
||||
npm run build
|
||||
# Deploy dist/ contents to /hp_chatbot/ path
|
||||
```
|
||||
|
||||
4. **Web Server Configuration**:
|
||||
- Configure reverse proxy (nginx/Apache)
|
||||
- Set up SSL certificates
|
||||
- Configure CORS origins
|
||||
|
||||
---
|
||||
|
||||
## Development
|
||||
|
||||
### Backend Development
|
||||
|
||||
```bash
|
||||
# Setup virtual environment
|
||||
python -m venv env
|
||||
source env/bin/activate # or env\Scripts\activate on Windows
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Start development server
|
||||
python main.py
|
||||
```
|
||||
|
||||
### Frontend Development
|
||||
|
||||
```bash
|
||||
cd chat-interface
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
### Key Development Commands
|
||||
|
||||
| Command | Purpose |
|
||||
|---------|---------|
|
||||
| `python main.py` | Start backend server |
|
||||
| `npm run dev` | Start frontend dev server |
|
||||
| `npm run build` | Build frontend for production |
|
||||
| `npm run lint` | Lint frontend code |
|
||||
|
||||
### Code Structure
|
||||
|
||||
```
|
||||
hp_graphRAG_bot/
|
||||
├── Backend (Python)
|
||||
│ ├── main.py # Application entry point
|
||||
│ ├── ai_core.py # Core AI engine & ReAct agent
|
||||
│ ├── graph_rag_integration.py # GraphRAG system
|
||||
│ ├── routes.py # Flask API routes
|
||||
│ ├── session_manager.py # Session management
|
||||
│ ├── mongodb_utils.py # MongoDB operations
|
||||
│ ├── config.py # Configuration
|
||||
│ └── shared_state.py # Global state management
|
||||
├── Frontend (React)
|
||||
│ └── chat-interface/
|
||||
│ ├── src/
|
||||
│ │ ├── App.jsx # Main application component
|
||||
│ │ ├── components/ # React components
|
||||
│ │ ├── auth.js # MSAL authentication
|
||||
│ │ └── lib/ # Utilities
|
||||
│ └── dist/ # Production build
|
||||
└── Data Storage
|
||||
├── uploads/images/ # Processed document images
|
||||
├── index_storage/ # Vector index data
|
||||
└── supporting_files/ # Source documents
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### Backend Issues
|
||||
|
||||
**Problem**: `Global workflow agent not initialized`
|
||||
**Solution**: Check OpenAI API key and Neo4j connectivity
|
||||
```bash
|
||||
# Verify environment variables
|
||||
echo $OPENAI_API_KEY
|
||||
# Check Neo4j connection
|
||||
curl http://localhost:7474
|
||||
```
|
||||
|
||||
**Problem**: `LlamaParse timeout during document processing`
|
||||
**Solution**: Increase timeout settings in config.py
|
||||
```python
|
||||
LLAMA_PARSE_MAX_TIMEOUT = 7200 # 2 hours
|
||||
```
|
||||
|
||||
**Problem**: `MongoDB connection failed`
|
||||
**Solution**: Verify MongoDB service and credentials
|
||||
```bash
|
||||
# Check MongoDB status
|
||||
brew services list | grep mongodb
|
||||
# Test connection
|
||||
mongosh mongodb://hp:hp@localhost:27017/hp_chatbot
|
||||
```
|
||||
|
||||
#### Frontend Issues
|
||||
|
||||
**Problem**: `CORS policy blocking requests`
|
||||
**Solution**: Update CORS_ALLOWED_ORIGINS in backend config.py
|
||||
|
||||
**Problem**: `Authentication failures`
|
||||
**Solution**: Verify MSAL configuration and Azure AD settings
|
||||
|
||||
**Problem**: `Images not loading`
|
||||
**Solution**: Check image file paths and backend /images/ endpoint
|
||||
|
||||
### Debug Endpoints
|
||||
|
||||
**Development Mode Only:**
|
||||
- `GET /debug-status` - System state inspection
|
||||
- `POST /reinitialize` - Force agent reinitialization
|
||||
- `POST /capture-screenshot` - Manual image extraction testing
|
||||
|
||||
### Logging
|
||||
|
||||
All components use structured logging:
|
||||
```python
|
||||
log_structured('info', 'Event description', {'key': 'value'})
|
||||
```
|
||||
|
||||
Log files locations:
|
||||
- Backend: `app.log`
|
||||
- MongoDB operations: `mongodb.log`
|
||||
|
||||
---
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Scaling
|
||||
- **Vector Index**: Consider PostgreSQL pgvector for large deployments
|
||||
- **Neo4j**: Implement read replicas for query scaling
|
||||
- **MongoDB**: Use connection pooling and sharding
|
||||
- **Caching**: Redis for session and community caches
|
||||
|
||||
### Optimization
|
||||
- **GraphRAG Communities**: Pre-computed and cached
|
||||
- **Image Processing**: Async processing with queue system
|
||||
- **Memory Management**: Agent memory reset policies
|
||||
- **Response Time**: Parallel vector and graph retrieval
|
||||
|
||||
---
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### Planned Features
|
||||
1. **Multi-tenant Architecture**: Support multiple organizations
|
||||
2. **Advanced Analytics**: Usage metrics and conversation insights
|
||||
3. **Enhanced Multimodal**: Video and audio processing
|
||||
4. **Real-time Collaboration**: Multi-user conversations
|
||||
5. **API Extensions**: Webhook integrations and external tool calling
|
||||
6. **Advanced Security**: Role-based access control and audit logging
|
||||
|
||||
### Technical Debt
|
||||
- Implement comprehensive test suite
|
||||
- Add API rate limiting
|
||||
- Improve error handling consistency
|
||||
- Optimize database queries
|
||||
- Add health check endpoints
|
||||
|
||||
---
|
||||
|
||||
*Documentation Version: 1.0*
|
||||
*Last Updated: 2024-01-01*
|
||||
*System Version: HP GraphRAG Chatbot v1.0*
|
||||
339
document_generator.py
Normal file
339
document_generator.py
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
# hp_chatbot/document_generator.py
|
||||
|
||||
import io
|
||||
import re
|
||||
import markdown2
|
||||
from bs4 import BeautifulSoup
|
||||
from docx import Document
|
||||
from docx.shared import Inches, Pt, RGBColor
|
||||
from docx.enum.text import WD_ALIGN_PARAGRAPH, WD_BREAK
|
||||
from docx.oxml.shared import OxmlElement, qn
|
||||
from docx.oxml import parse_xml
|
||||
|
||||
from utils import log_structured
|
||||
|
||||
# --- Helper for Horizontal Line ---
|
||||
def add_horizontal_line(paragraph):
|
||||
"""Adds a horizontal line after the specified paragraph."""
|
||||
p = paragraph._p # the lxml element beneath the paragraph
|
||||
pPr = p.get_or_add_pPr() # Get or add paragraph properties element
|
||||
pBdr = OxmlElement('w:pBdr') # Create paragraph border element
|
||||
# Add a bottom border
|
||||
bottom_bdr = OxmlElement('w:bottom')
|
||||
bottom_bdr.set(qn('w:val'), 'single') # Border style
|
||||
bottom_bdr.set(qn('w:sz'), '6') # Border size (in 1/8 points)
|
||||
bottom_bdr.set(qn('w:space'), '1') # Space between text and border
|
||||
bottom_bdr.set(qn('w:color'), 'auto') # Border color
|
||||
pBdr.append(bottom_bdr)
|
||||
pPr.append(pBdr)
|
||||
|
||||
# --- Inline Markdown to DOCX Run Formatting ---
|
||||
def process_inline_formatting(paragraph, text):
|
||||
"""
|
||||
Processes simple inline markdown (bold, italic, code) within text
|
||||
and adds formatted runs to the paragraph.
|
||||
Handles nested formatting cautiously.
|
||||
"""
|
||||
# Regex to find **bold**, *italic*, _italic_, `code` segments
|
||||
# It captures the marker and the content separately.
|
||||
pattern = r'(\*\*|`|\*|_)(.*?)(\1)'
|
||||
last_end = 0
|
||||
|
||||
for match in re.finditer(pattern, text):
|
||||
start, end = match.span()
|
||||
marker = match.group(1)
|
||||
content = match.group(2)
|
||||
|
||||
# Add preceding text if any
|
||||
if start > last_end:
|
||||
paragraph.add_run(text[last_end:start])
|
||||
|
||||
# Add formatted run
|
||||
run = paragraph.add_run(content)
|
||||
if marker == '**':
|
||||
run.bold = True
|
||||
elif marker == '*' or marker == '_':
|
||||
run.italic = True
|
||||
elif marker == '`':
|
||||
run.font.name = 'Courier New'
|
||||
# run.font.size = Pt(10) # Optional: Set size for code
|
||||
|
||||
last_end = end
|
||||
|
||||
# Add any remaining text after the last match
|
||||
if last_end < len(text):
|
||||
paragraph.add_run(text[last_end:])
|
||||
|
||||
|
||||
# --- HTML to DOCX Conversion ---
|
||||
def convert_html_to_docx(doc: Document, html_content: str):
|
||||
"""
|
||||
Converts basic HTML content (from markdown conversion) to Word elements.
|
||||
Handles common tags like paragraphs, headings, lists, bold, italic, code.
|
||||
"""
|
||||
# Pre-process HTML slightly for cleaner parsing
|
||||
html_content = re.sub(r'\s*\n\s*', '\n', html_content).strip() # Normalize whitespace
|
||||
html_content = f"<body>{html_content}</body>" # Wrap in body for better parsing
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
|
||||
# Recursive function to handle elements
|
||||
def process_element(element, current_paragraph=None, current_style=None, in_list=False):
|
||||
# Skip NavigableString if it's just whitespace or newline outside pre
|
||||
if isinstance(element, str):
|
||||
text = str(element).strip('\n') # Keep internal spaces, strip leading/trailing newlines
|
||||
if text: # Only add if there's actual content
|
||||
if current_paragraph:
|
||||
run = current_paragraph.add_run(text)
|
||||
if current_style:
|
||||
if 'bold' in current_style: run.bold = True
|
||||
if 'italic' in current_style: run.italic = True
|
||||
if 'code' in current_style: run.font.name = 'Courier New'
|
||||
else:
|
||||
# Text outside paragraph usually means an error or whitespace
|
||||
# log_structured('debug', f"Orphan text node found: '{text[:50]}...'")
|
||||
pass # Or create a default paragraph: doc.add_paragraph(text)
|
||||
return
|
||||
|
||||
# --- Block Level Elements ---
|
||||
if element.name in ['p', 'div']:
|
||||
# Avoid creating paragraphs for empty containers unless they contain <br>
|
||||
text_content = element.get_text(strip=True)
|
||||
has_br = element.find('br')
|
||||
if text_content or has_br:
|
||||
para = doc.add_paragraph()
|
||||
# Apply list indentation if necessary (though lists handle their own paras)
|
||||
# if in_list: para.paragraph_format.left_indent = Inches(0.5)
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
for child in element.children:
|
||||
process_element(child, para, new_style, in_list)
|
||||
# else: skip empty p/div
|
||||
|
||||
elif element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
|
||||
try:
|
||||
level = int(element.name[1])
|
||||
heading = doc.add_heading(level=level)
|
||||
# Process children for inline formatting within heading
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
for child in element.children:
|
||||
process_element(child, heading, new_style, in_list)
|
||||
# If no children processed (just text), add it directly
|
||||
if not heading.runs:
|
||||
heading.add_run(element.get_text(strip=True))
|
||||
except ValueError: pass # Should not happen with h1-h6
|
||||
|
||||
elif element.name == 'ul':
|
||||
for li in element.find_all('li', recursive=False):
|
||||
# Each li gets its own paragraph with bullet style
|
||||
para = doc.add_paragraph(style='List Bullet')
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
for child in li.children:
|
||||
process_element(child, para, new_style, in_list=True)
|
||||
# If li was empty or only contained whitespace
|
||||
if not para.text.strip():
|
||||
para.text = "" # Ensure empty bullet point exists
|
||||
|
||||
elif element.name == 'ol':
|
||||
# Numbering is handled by the 'List Number' style
|
||||
for li in element.find_all('li', recursive=False):
|
||||
para = doc.add_paragraph(style='List Number')
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
for child in li.children:
|
||||
process_element(child, para, new_style, in_list=True)
|
||||
if not para.text.strip():
|
||||
para.text = "" # Ensure empty numbered item exists
|
||||
|
||||
# Note: 'li' is handled within 'ul'/'ol' processing.
|
||||
|
||||
elif element.name == 'pre':
|
||||
# Often contains a 'code' element, handle that
|
||||
code_tag = element.find('code')
|
||||
content = code_tag.get_text() if code_tag else element.get_text()
|
||||
if content.strip():
|
||||
para = doc.add_paragraph(style='CodeStyle') # Requires 'CodeStyle' to be defined
|
||||
# Preserve whitespace more carefully for <pre>
|
||||
run = para.add_run(content.strip('\n')) # Strip outer newlines only
|
||||
run.font.name = 'Courier New'
|
||||
# run.font.size = Pt(10)
|
||||
|
||||
elif element.name == 'blockquote':
|
||||
para = doc.add_paragraph(style='Quote') # Requires 'Quote' style
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
for child in element.children:
|
||||
process_element(child, para, new_style, in_list)
|
||||
|
||||
elif element.name == 'hr':
|
||||
para = doc.add_paragraph()
|
||||
add_horizontal_line(para)
|
||||
|
||||
elif element.name == 'br':
|
||||
if current_paragraph:
|
||||
current_paragraph.add_run().add_break() # Add line break within paragraph
|
||||
|
||||
|
||||
# --- Inline Elements ---
|
||||
elif element.name in ['strong', 'b']:
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
new_style.add('bold')
|
||||
for child in element.children:
|
||||
process_element(child, current_paragraph, new_style, in_list)
|
||||
|
||||
elif element.name in ['em', 'i']:
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
new_style.add('italic')
|
||||
for child in element.children:
|
||||
process_element(child, current_paragraph, new_style, in_list)
|
||||
|
||||
elif element.name == 'code':
|
||||
# Handle inline code - assumes it's within a paragraph already
|
||||
if current_paragraph:
|
||||
text = element.get_text()
|
||||
if text:
|
||||
run = current_paragraph.add_run(text)
|
||||
run.font.name = 'Courier New'
|
||||
# Add specific inline code style if desired
|
||||
else:
|
||||
# Code tag not within a paragraph? Create one.
|
||||
para = doc.add_paragraph(style='CodeStyle')
|
||||
run = para.add_run(element.get_text())
|
||||
run.font.name = 'Courier New'
|
||||
|
||||
|
||||
elif element.name == 'a':
|
||||
# Add hyperlink if possible, otherwise just text
|
||||
text = element.get_text(strip=True)
|
||||
href = element.get('href')
|
||||
if current_paragraph and text:
|
||||
# python-docx doesn't have direct hyperlink support easily added here.
|
||||
# Simplest: add text with underline and blue color.
|
||||
run = current_paragraph.add_run(text)
|
||||
run.underline = True
|
||||
run.font.color.rgb = RGBColor(0x05, 0x63, 0xC1) # Standard link blue
|
||||
# For actual hyperlinks, more complex XML manipulation is needed.
|
||||
|
||||
|
||||
# --- Body/Other Tags: Process children ---
|
||||
elif element.name in ['body', 'span', 'div']: # Treat span/div mostly as containers
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
for child in element.children:
|
||||
process_element(child, current_paragraph, new_style, in_list)
|
||||
|
||||
# --- Ignored Tags ---
|
||||
elif element.name in ['script', 'style', 'head', 'meta', 'title']:
|
||||
pass # Ignore these tags and their content
|
||||
|
||||
else:
|
||||
# Unknown tag: try to process its children if it's a container,
|
||||
# or add its text content if it's inline-like.
|
||||
log_structured('warning', f"Unhandled HTML tag encountered: <{element.name}>", {'content_preview': element.get_text(strip=True)[:50]})
|
||||
# Default behavior: process children recursively
|
||||
new_style = current_style.copy() if current_style else set()
|
||||
for child in element.children:
|
||||
process_element(child, current_paragraph, new_style, in_list)
|
||||
|
||||
|
||||
# Start processing from the top-level elements within the parsed body
|
||||
body = soup.find('body')
|
||||
if body:
|
||||
for element in body.children:
|
||||
process_element(element)
|
||||
|
||||
|
||||
# --- Main Markdown to DOCX Function ---
|
||||
def create_brief_docx(brief_content_markdown: str) -> io.BytesIO:
|
||||
"""
|
||||
Creates a Word document (.docx) in memory from markdown content.
|
||||
|
||||
Args:
|
||||
brief_content_markdown: The markdown string content.
|
||||
|
||||
Returns:
|
||||
An io.BytesIO buffer containing the Word document.
|
||||
"""
|
||||
doc = Document()
|
||||
|
||||
# --- Define Styles (Optional but recommended) ---
|
||||
styles = doc.styles
|
||||
# Normal style
|
||||
style = styles['Normal']
|
||||
font = style.font
|
||||
font.name = 'Calibri' # Or HP specific font if available
|
||||
font.size = Pt(11)
|
||||
|
||||
# Code style (example)
|
||||
try:
|
||||
code_style = styles.add_style('CodeStyle', 1) # 1 for paragraph style
|
||||
code_style.font.name = 'Courier New'
|
||||
code_style.font.size = Pt(10)
|
||||
# Prevent spell check for code blocks
|
||||
code_style.element.rPr.rFonts.set(qn('w:ascii'), 'Courier New')
|
||||
code_style.element.rPr.rFonts.set(qn('w:hAnsi'), 'Courier New')
|
||||
# code_style.element.xpath('./w:rPr/w:lang')[0].set(qn('w:noProof'), '1') # Requires lxml maybe
|
||||
p_fmt = code_style.paragraph_format
|
||||
p_fmt.space_before = Pt(6)
|
||||
p_fmt.space_after = Pt(6)
|
||||
except ValueError:
|
||||
log_structured('warning', "'CodeStyle' already exists. Using existing.")
|
||||
code_style = styles['CodeStyle'] # Use existing if it fails to add
|
||||
|
||||
# Quote style (example)
|
||||
try:
|
||||
quote_style = styles.add_style('QuoteStyle', 1)
|
||||
quote_style.font.italic = True
|
||||
quote_style.paragraph_format.left_indent = Inches(0.5)
|
||||
quote_style.paragraph_format.space_before = Pt(6)
|
||||
quote_style.paragraph_format.space_after = Pt(6)
|
||||
except ValueError:
|
||||
log_structured('warning', "'QuoteStyle' already exists. Using existing.")
|
||||
quote_style = styles['QuoteStyle']
|
||||
|
||||
|
||||
# --- Document Header ---
|
||||
title = doc.add_heading('Marketing Brief', 0)
|
||||
title.alignment = WD_ALIGN_PARAGRAPH.CENTER
|
||||
|
||||
date_para = doc.add_paragraph()
|
||||
date_para.alignment = WD_ALIGN_PARAGRAPH.RIGHT
|
||||
date_run = date_para.add_run(datetime.now().strftime("%B %d, %Y"))
|
||||
date_run.italic = True
|
||||
# Add some space after date
|
||||
date_para.paragraph_format.space_after = Pt(12)
|
||||
|
||||
# Add a horizontal line separator
|
||||
hr_para = doc.add_paragraph()
|
||||
add_horizontal_line(hr_para)
|
||||
hr_para.paragraph_format.space_after = Pt(18) # Space after the line
|
||||
|
||||
|
||||
# --- Convert Markdown to HTML ---
|
||||
# Using markdown2 with recommended extras for broad compatibility
|
||||
extras = [
|
||||
"tables", "fenced-code-blocks", "header-ids", "footnotes",
|
||||
"task_list", "code-friendly", "cuddled-lists", "markdown-in-html",
|
||||
"strike", "spoiler", "target-blank-links", "smarty-pants" # Added smarty-pants
|
||||
]
|
||||
html_content = markdown2.markdown(brief_content_markdown, extras=extras)
|
||||
|
||||
log_structured('debug', 'Converted markdown to HTML for DOCX generation', {
|
||||
'md_preview': brief_content_markdown[:200],
|
||||
'html_preview': html_content[:300]
|
||||
})
|
||||
|
||||
# --- Convert HTML to Word Document Elements ---
|
||||
try:
|
||||
convert_html_to_docx(doc, html_content)
|
||||
except Exception as conversion_err:
|
||||
log_structured('error', "Error during HTML to DOCX conversion", {
|
||||
'error': str(conversion_err),
|
||||
'traceback': traceback.format_exc()
|
||||
})
|
||||
# Add error message to the document itself
|
||||
doc.add_paragraph("Error: Could not fully convert content from HTML to DOCX.", style='Emphasis')
|
||||
doc.add_paragraph(str(conversion_err))
|
||||
|
||||
|
||||
# --- Save to Buffer ---
|
||||
doc_buffer = io.BytesIO()
|
||||
doc.save(doc_buffer)
|
||||
doc_buffer.seek(0)
|
||||
|
||||
return doc_buffer
|
||||
581
graphRAG.py
Normal file
581
graphRAG.py
Normal file
|
|
@ -0,0 +1,581 @@
|
|||
import os
|
||||
import json
|
||||
import re
|
||||
import asyncio
|
||||
import networkx as nx
|
||||
from collections import defaultdict
|
||||
from typing import Any, List, Callable, Optional, Union, Dict
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# Import LlamaIndex components
|
||||
from llama_index.core import Document, Settings
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core import PropertyGraphIndex
|
||||
from llama_index.core.async_utils import run_jobs
|
||||
from llama_index.core.indices.property_graph.utils import default_parse_triplets_fn
|
||||
from llama_index.core.graph_stores.types import EntityNode, KG_NODES_KEY, KG_RELATIONS_KEY, Relation
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
|
||||
from llama_index.core.schema import TransformComponent, BaseNode
|
||||
from llama_index.core.query_engine import CustomQueryEngine
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
|
||||
# Community detection (using NetworkX instead of graspologic as a fallback)
|
||||
try:
|
||||
from community import best_partition # python-louvain package
|
||||
except ImportError:
|
||||
print("Community detection package not found, using NetworkX built-in community detection")
|
||||
|
||||
# Load environment variables from .env file
|
||||
load_dotenv()
|
||||
|
||||
# Use API key from environment as fallback
|
||||
if not os.environ.get("OPENAI_API_KEY"):
|
||||
os.environ["OPENAI_API_KEY"] = "sk-proj-wXcoIn81Vwg4Iaw0vhmYT3BlbkFJmt1eOxeEAF1juUfhzMtk"
|
||||
|
||||
# Define the GraphRAGExtractor class
|
||||
class GraphRAGExtractor(TransformComponent):
|
||||
"""Extract triples from a graph.
|
||||
|
||||
Uses an LLM and a simple prompt + output parsing to
|
||||
extract paths (i.e. triples) and entity, relation descriptions
|
||||
from text.
|
||||
|
||||
Args:
|
||||
llm (LLM):
|
||||
The language model to use.
|
||||
extract_prompt (Union[str, PromptTemplate]):
|
||||
The prompt to use for extracting triples.
|
||||
parse_fn (callable):
|
||||
A function to parse the output of the language
|
||||
model.
|
||||
num_workers (int):
|
||||
The number of workers to use for parallel
|
||||
processing.
|
||||
max_paths_per_chunk (int):
|
||||
The maximum number of paths to extract per chunk.
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
extract_prompt: PromptTemplate
|
||||
parse_fn: Callable
|
||||
num_workers: int
|
||||
max_paths_per_chunk: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: Optional[LLM] = None,
|
||||
extract_prompt: Optional[Union[str, PromptTemplate]] = None,
|
||||
parse_fn: Callable = default_parse_triplets_fn,
|
||||
max_paths_per_chunk: int = 10,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
from llama_index.core import Settings
|
||||
|
||||
if isinstance(extract_prompt, str):
|
||||
extract_prompt = PromptTemplate(extract_prompt)
|
||||
|
||||
super().__init__(
|
||||
llm=llm or Settings.llm,
|
||||
extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
|
||||
parse_fn=parse_fn,
|
||||
num_workers=num_workers,
|
||||
max_paths_per_chunk=max_paths_per_chunk,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "GraphExtractor"
|
||||
|
||||
def __call__(
|
||||
self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
|
||||
) -> List[BaseNode]:
|
||||
"""Extract triples from nodes."""
|
||||
return asyncio.run(
|
||||
self.acall(nodes, show_progress=show_progress, **kwargs)
|
||||
)
|
||||
|
||||
async def _aextract(self, node: BaseNode) -> BaseNode:
|
||||
"""Extract triples from a node."""
|
||||
assert hasattr(node, "text")
|
||||
|
||||
text = node.get_content(metadata_mode="llm")
|
||||
try:
|
||||
llm_response = await self.llm.apredict(
|
||||
self.extract_prompt,
|
||||
text=text,
|
||||
max_knowledge_triplets=self.max_paths_per_chunk,
|
||||
)
|
||||
entities, entities_relationship = self.parse_fn(llm_response)
|
||||
except ValueError:
|
||||
entities = []
|
||||
entities_relationship = []
|
||||
|
||||
existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
|
||||
existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
|
||||
|
||||
entity_metadata = node.metadata.copy()
|
||||
for entity, entity_type, description in entities:
|
||||
entity_metadata["entity_description"] = description
|
||||
entity_node = EntityNode(
|
||||
name=entity, label=entity_type,
|
||||
properties=entity_metadata
|
||||
)
|
||||
existing_nodes.append(entity_node)
|
||||
|
||||
relation_metadata = node.metadata.copy()
|
||||
for triple in entities_relationship:
|
||||
subj, obj, rel, description = triple
|
||||
relation_metadata["relationship_description"] = description
|
||||
rel_node = Relation(
|
||||
label=rel,
|
||||
source_id=subj,
|
||||
target_id=obj,
|
||||
properties=relation_metadata,
|
||||
)
|
||||
existing_relations.append(rel_node)
|
||||
|
||||
node.metadata[KG_NODES_KEY] = existing_nodes
|
||||
node.metadata[KG_RELATIONS_KEY] = existing_relations
|
||||
return node
|
||||
|
||||
async def acall(
|
||||
self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
|
||||
) -> List[BaseNode]:
|
||||
"""Extract triples from nodes async."""
|
||||
jobs = []
|
||||
for node in nodes:
|
||||
jobs.append(self._aextract(node))
|
||||
|
||||
return await run_jobs(
|
||||
jobs,
|
||||
workers=self.num_workers,
|
||||
show_progress=show_progress,
|
||||
desc="Extracting paths from text",
|
||||
)
|
||||
|
||||
# Define the GraphRAGStore class
|
||||
class GraphRAGStore(Neo4jPropertyGraphStore):
|
||||
community_summary = {}
|
||||
entity_info = None
|
||||
max_cluster_size = 5
|
||||
|
||||
def generate_community_summary(self, text):
|
||||
"""Generate summary for a given text using an LLM."""
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role="system",
|
||||
content=(
|
||||
"You are provided with a set of "
|
||||
"relationships from a knowledge graph, each represented as "
|
||||
"entity1->entity2->relation-"
|
||||
">relationship_description. Your task is to create a summary of "
|
||||
"these "
|
||||
"relationships. The summary should include "
|
||||
"the names of the entities involved and a concise synthesis "
|
||||
"of the relationship descriptions. The "
|
||||
"goal is to capture the most critical and relevant details that "
|
||||
"highlight the nature and significance of "
|
||||
"each relationship. Ensure that the summary is coherent and "
|
||||
"integrates the information in a way that "
|
||||
"emphasizes the key aspects of the relationships."
|
||||
),
|
||||
),
|
||||
ChatMessage(role="user", content=text),
|
||||
]
|
||||
response = OpenAI().chat(messages)
|
||||
clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
|
||||
return clean_response
|
||||
|
||||
def build_communities(self):
|
||||
"""Builds communities from the graph and summarizes them."""
|
||||
nx_graph = self._create_nx_graph()
|
||||
|
||||
# Use either Leiden algorithm (from graspologic) or an alternative
|
||||
try:
|
||||
from graspologic.partition import hierarchical_leiden
|
||||
community_hierarchical_clusters = hierarchical_leiden(
|
||||
nx_graph, max_cluster_size=self.max_cluster_size
|
||||
)
|
||||
self.entity_info, community_info = self._collect_community_info(
|
||||
nx_graph, community_hierarchical_clusters
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback to community detection using NetworkX or python-louvain
|
||||
try:
|
||||
from community import best_partition
|
||||
partition = best_partition(nx_graph)
|
||||
# Reformat partition data to expected structure
|
||||
clusters = []
|
||||
for node, cluster_id in partition.items():
|
||||
class Cluster:
|
||||
def __init__(self, node, cluster):
|
||||
self.node = node
|
||||
self.cluster = cluster
|
||||
clusters.append(Cluster(node, cluster_id))
|
||||
self.entity_info, community_info = self._collect_community_info(
|
||||
nx_graph, clusters
|
||||
)
|
||||
except ImportError:
|
||||
# Use NetworkX's built-in community detection
|
||||
from networkx.algorithms import community
|
||||
communities = community.greedy_modularity_communities(nx_graph)
|
||||
clusters = []
|
||||
for i, comm in enumerate(communities):
|
||||
for node in comm:
|
||||
class Cluster:
|
||||
def __init__(self, node, cluster):
|
||||
self.node = node
|
||||
self.cluster = cluster
|
||||
clusters.append(Cluster(node, i))
|
||||
self.entity_info, community_info = self._collect_community_info(
|
||||
nx_graph, clusters
|
||||
)
|
||||
|
||||
self._summarize_communities(community_info)
|
||||
|
||||
def _create_nx_graph(self):
|
||||
"""Converts internal graph representation to NetworkX graph."""
|
||||
nx_graph = nx.Graph()
|
||||
triplets = self.get_triplets()
|
||||
for entity1, relation, entity2 in triplets:
|
||||
nx_graph.add_node(entity1.name)
|
||||
nx_graph.add_node(entity2.name)
|
||||
nx_graph.add_edge(
|
||||
relation.source_id,
|
||||
relation.target_id,
|
||||
relationship=relation.label,
|
||||
description=relation.properties.get("relationship_description", "No description provided"),
|
||||
)
|
||||
return nx_graph
|
||||
|
||||
def _collect_community_info(self, nx_graph, clusters):
|
||||
"""
|
||||
Collect information for each node based on their community,
|
||||
allowing entities to belong to multiple clusters.
|
||||
"""
|
||||
entity_info = defaultdict(set)
|
||||
community_info = defaultdict(list)
|
||||
|
||||
for item in clusters:
|
||||
node = item.node
|
||||
cluster_id = item.cluster
|
||||
|
||||
# Update entity_info
|
||||
entity_info[node].add(cluster_id)
|
||||
|
||||
for neighbor in nx_graph.neighbors(node):
|
||||
edge_data = nx_graph.get_edge_data(node, neighbor)
|
||||
if edge_data:
|
||||
detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
|
||||
community_info[cluster_id].append(detail)
|
||||
|
||||
# Convert sets to lists for easier serialization if needed
|
||||
entity_info = {k: list(v) for k, v in entity_info.items()}
|
||||
|
||||
return dict(entity_info), dict(community_info)
|
||||
|
||||
def _summarize_communities(self, community_info):
|
||||
"""Generate and store summaries for each community."""
|
||||
for community_id, details in community_info.items():
|
||||
details_text = "\n".join(details) + "." # Ensure it ends with a period
|
||||
self.community_summary[community_id] = self.generate_community_summary(details_text)
|
||||
|
||||
def get_community_summaries(self):
|
||||
"""Returns the community summaries, building them if not already done."""
|
||||
if not self.community_summary:
|
||||
self.build_communities()
|
||||
return self.community_summary
|
||||
|
||||
# Define the GraphRAGQueryEngine class
|
||||
class GraphRAGQueryEngine(CustomQueryEngine):
|
||||
graph_store: Union[GraphRAGStore, Any] # Accept any type of graph store
|
||||
index: PropertyGraphIndex
|
||||
llm: LLM
|
||||
similarity_top_k: int = 20
|
||||
|
||||
def custom_query(self, query_str: str) -> str:
|
||||
"""Process query using either community-based approach or direct retrieval."""
|
||||
# Check if we're using GraphRAGStore with communities or SimplePropertyGraphStore
|
||||
if hasattr(self.graph_store, 'get_community_summaries'):
|
||||
# GraphRAG approach with communities
|
||||
entities = self.get_entities(query_str, self.similarity_top_k)
|
||||
|
||||
community_ids = self.retrieve_entity_communities(
|
||||
self.graph_store.entity_info, entities
|
||||
)
|
||||
community_summaries = self.graph_store.get_community_summaries()
|
||||
community_answers = [
|
||||
self.generate_answer_from_summary(community_summary, query_str)
|
||||
for id, community_summary in community_summaries.items()
|
||||
if id in community_ids
|
||||
]
|
||||
|
||||
final_answer = self.aggregate_answers(community_answers)
|
||||
return final_answer
|
||||
else:
|
||||
# Simple approach for SimplePropertyGraphStore
|
||||
# Just get relevant nodes and generate answer
|
||||
nodes = self.index.as_retriever(
|
||||
similarity_top_k=self.similarity_top_k
|
||||
).retrieve(query_str)
|
||||
|
||||
if not nodes:
|
||||
return "I couldn't find any relevant information to answer your question."
|
||||
|
||||
# Combine text from all retrieved nodes
|
||||
context = "\n\n".join([node.get_content() for node in nodes])
|
||||
|
||||
# Generate answer using the LLM
|
||||
prompt = f"Based on the following information, please answer this question: {query_str}\n\nInformation:\n{context}"
|
||||
messages = [
|
||||
ChatMessage(role="system", content=prompt),
|
||||
ChatMessage(role="user", content="Please provide a comprehensive answer based on the information provided.")
|
||||
]
|
||||
response = self.llm.chat(messages)
|
||||
return str(response).strip()
|
||||
|
||||
def get_entities(self, query_str, similarity_top_k):
|
||||
nodes_retrieved = self.index.as_retriever(
|
||||
similarity_top_k=similarity_top_k
|
||||
).retrieve(query_str)
|
||||
|
||||
entities = set()
|
||||
pattern = r"^(\w+(?:\s+\w+)*)\s*->\s*([a-zA-Z\s]+?)\s*->\s*(\w+(?:\s+\w+)*)$"
|
||||
|
||||
for node in nodes_retrieved:
|
||||
matches = re.findall(
|
||||
pattern, node.text, re.MULTILINE | re.IGNORECASE
|
||||
)
|
||||
for match in matches:
|
||||
subject = match[0]
|
||||
obj = match[2]
|
||||
entities.add(subject)
|
||||
entities.add(obj)
|
||||
|
||||
return list(entities)
|
||||
|
||||
def retrieve_entity_communities(self, entity_info, entities):
|
||||
"""
|
||||
Retrieve cluster information for given entities,
|
||||
allowing for multiple clusters per entity.
|
||||
|
||||
Args:
|
||||
entity_info (dict): Dictionary mapping entities to their cluster IDs (list).
|
||||
entities (list): List of entity names to retrieve information for.
|
||||
|
||||
Returns:
|
||||
List of community or cluster IDs to which an entity belongs.
|
||||
"""
|
||||
community_ids = []
|
||||
|
||||
for entity in entities:
|
||||
if entity in entity_info:
|
||||
community_ids.extend(entity_info[entity])
|
||||
|
||||
return list(set(community_ids))
|
||||
|
||||
def generate_answer_from_summary(self, community_summary, query):
|
||||
"""Generate an answer from a community summary based on a given query using LLM."""
|
||||
prompt = (
|
||||
f"Given the community summary: {community_summary}, "
|
||||
f"how would you answer the following query? Query: {query}"
|
||||
)
|
||||
messages = [
|
||||
ChatMessage(role="system", content=prompt),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
content="I need an answer based on the above information.",
|
||||
),
|
||||
]
|
||||
response = self.llm.chat(messages)
|
||||
cleaned_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
|
||||
return cleaned_response
|
||||
|
||||
def aggregate_answers(self, community_answers):
|
||||
"""Aggregate individual community answers into a final, coherent response."""
|
||||
prompt = "Combine the following intermediate answers into a final, concise response."
|
||||
messages = [
|
||||
ChatMessage(role="system", content=prompt),
|
||||
ChatMessage(
|
||||
role="user",
|
||||
content=f"Intermediate answers: {community_answers}",
|
||||
),
|
||||
]
|
||||
final_response = self.llm.chat(messages)
|
||||
cleaned_final_response = re.sub(
|
||||
r"^assistant:\s*", "", str(final_response)
|
||||
).strip()
|
||||
return cleaned_final_response
|
||||
|
||||
def custom_parse_fn(response_str: str) -> Any:
|
||||
"""Custom parser for LLM responses that extract entities and relationships"""
|
||||
json_pattern = r"\{.*\}"
|
||||
match = re.search(json_pattern, response_str, re.DOTALL)
|
||||
entities = []
|
||||
relationships = []
|
||||
|
||||
if not match:
|
||||
return entities, relationships
|
||||
|
||||
json_str = match.group(0)
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
entities = [
|
||||
(
|
||||
entity["entity_name"],
|
||||
entity["entity_type"],
|
||||
entity.get("entity_description", f"Description of {entity['entity_name']}"),
|
||||
)
|
||||
for entity in data.get("entities", [])
|
||||
]
|
||||
relationships = [
|
||||
(
|
||||
relation["source_entity"],
|
||||
relation["target_entity"],
|
||||
relation["relation"],
|
||||
relation.get("relationship_description", f"Relationship between {relation['source_entity']} and {relation['target_entity']}"),
|
||||
)
|
||||
for relation in data.get("relationships", [])
|
||||
]
|
||||
return entities, relationships
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
print(f"Error parsing response: {e}")
|
||||
print(f"Problematic JSON: {json_str[:200]}...")
|
||||
return entities, relationships
|
||||
|
||||
# Define the prompt template for triple extraction
|
||||
KG_TRIPLET_EXTRACT_TMPL = """
|
||||
-Goal-
|
||||
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
|
||||
|
||||
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.
|
||||
|
||||
-Steps-
|
||||
1. Identify all entities. For each identified entity, extract the following information:
|
||||
- entity_name: Name of the entity, capitalized
|
||||
- entity_type: Type of the entity
|
||||
- entity_description: Comprehensive description of the entity's attributes and activities
|
||||
|
||||
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
|
||||
For each pair of related entities, extract the following information:
|
||||
- source_entity: name of the source entity, as identified in step 1
|
||||
- target_entity: name of the target entity, as identified in step 1
|
||||
- relation: relationship between source_entity and target_entity
|
||||
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
|
||||
|
||||
3. Output Formatting:
|
||||
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
|
||||
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
|
||||
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.
|
||||
|
||||
-Real Data-
|
||||
######################
|
||||
text: {text}
|
||||
######################
|
||||
output:
|
||||
"""
|
||||
|
||||
def main():
|
||||
print("Starting GraphRAG document processing...")
|
||||
|
||||
# Load documents from specified directory
|
||||
documents = SimpleDirectoryReader(
|
||||
input_dir="supporting_files/files_for_rag_store"
|
||||
).load_data()
|
||||
|
||||
print(f"Loaded {len(documents)} documents")
|
||||
|
||||
# Create nodes using a sentence splitter
|
||||
splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20)
|
||||
nodes = splitter.get_nodes_from_documents(documents)
|
||||
|
||||
print(f"Created {len(nodes)} nodes from documents")
|
||||
|
||||
# Initialize the LLM
|
||||
llm = OpenAI(model="gpt-4")
|
||||
|
||||
# Create the knowledge graph extractor
|
||||
kg_extractor = GraphRAGExtractor(
|
||||
llm=llm,
|
||||
extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
|
||||
max_paths_per_chunk=2,
|
||||
parse_fn=custom_parse_fn,
|
||||
)
|
||||
|
||||
# Connect to Neo4j running in Docker
|
||||
neo4j_username = "neo4j"
|
||||
neo4j_password = "tavern-easy-museum-arthur-coconut-3483"
|
||||
neo4j_url = "bolt://localhost:7687"
|
||||
|
||||
print(f"Connecting to Neo4j at {neo4j_url} with username '{neo4j_username}'")
|
||||
|
||||
# Create GraphRAGStore (our extended Neo4j store)
|
||||
try:
|
||||
graph_store = GraphRAGStore(
|
||||
username=neo4j_username,
|
||||
password=neo4j_password,
|
||||
url=neo4j_url
|
||||
)
|
||||
print("Successfully connected to Neo4j database")
|
||||
except Exception as e:
|
||||
print(f"Error connecting to Neo4j: {e}")
|
||||
print("Falling back to in-memory graph store. Some features may be limited.")
|
||||
# Fallback to in-memory graph store if Neo4j connection fails
|
||||
from llama_index.core.graph_stores import SimplePropertyGraphStore
|
||||
graph_store = SimplePropertyGraphStore()
|
||||
|
||||
# Build the index
|
||||
index = PropertyGraphIndex(
|
||||
nodes=nodes,
|
||||
kg_extractors=[kg_extractor],
|
||||
property_graph_store=graph_store,
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
print("Building graph communities...")
|
||||
try:
|
||||
# Build communities for graph-based querying
|
||||
# Only for GraphRAGStore, not for SimplePropertyGraphStore
|
||||
if hasattr(graph_store, 'build_communities'):
|
||||
graph_store.build_communities()
|
||||
print("Communities built successfully")
|
||||
else:
|
||||
print("Skipping community building (not using Neo4j)")
|
||||
except Exception as e:
|
||||
print(f"Error building communities: {e}")
|
||||
|
||||
# Create the query engine
|
||||
query_engine = GraphRAGQueryEngine(
|
||||
graph_store=graph_store,
|
||||
llm=llm,
|
||||
index=index,
|
||||
similarity_top_k=10,
|
||||
)
|
||||
|
||||
# Simple interactive query loop
|
||||
print("\n--- GraphRAG Query System Ready ---")
|
||||
print("Type 'exit' to quit")
|
||||
|
||||
while True:
|
||||
query = input("\nEnter your query: ")
|
||||
|
||||
if query.lower() in ('exit', 'quit'):
|
||||
break
|
||||
|
||||
try:
|
||||
response = query_engine.custom_query(query)
|
||||
print("\nResponse:")
|
||||
print(response)
|
||||
except Exception as e:
|
||||
print(f"Error processing query: {e}")
|
||||
|
||||
print("GraphRAG session ended")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
882
graph_rag_integration.py
Normal file
882
graph_rag_integration.py
Normal file
|
|
@ -0,0 +1,882 @@
|
|||
"""
|
||||
HP GraphRAG Integration
|
||||
|
||||
Integrates GraphRAG functionality into the HP RAG pipeline.
|
||||
- GraphRAG for knowledge graph construction from semantically split nodes
|
||||
- Community detection and summarization for improved context retrieval
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import re
|
||||
import asyncio
|
||||
import networkx as nx
|
||||
from collections import defaultdict
|
||||
from typing import Any, List, Callable, Optional, Union, Dict
|
||||
from pathlib import Path
|
||||
|
||||
# Import LlamaIndex components
|
||||
from llama_index.core import Document, Settings
|
||||
from llama_index.core.node_parser import SentenceSplitter
|
||||
from llama_index.core import PropertyGraphIndex
|
||||
from llama_index.core.async_utils import run_jobs
|
||||
from llama_index.core.indices.property_graph.utils import default_parse_triplets_fn
|
||||
from llama_index.core.graph_stores.types import EntityNode, KG_NODES_KEY, KG_RELATIONS_KEY, Relation
|
||||
from llama_index.core.llms.llm import LLM
|
||||
from llama_index.core.prompts import PromptTemplate
|
||||
from llama_index.core.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT
|
||||
from llama_index.core.schema import TransformComponent, BaseNode
|
||||
from llama_index.core.query_engine import CustomQueryEngine
|
||||
from llama_index.llms.openai import OpenAI
|
||||
from llama_index.core.llms import ChatMessage
|
||||
from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
|
||||
from llama_index.core import SimpleDirectoryReader
|
||||
from llama_index.core.vector_stores.types import VectorStoreInfo, MetadataInfo
|
||||
from llama_index.core.retrievers import VectorIndexRetriever
|
||||
|
||||
# Community detection (using NetworkX instead of graspologic as a fallback)
|
||||
try:
|
||||
from community import best_partition # python-louvain package
|
||||
except ImportError:
|
||||
print("Community detection package not found, using NetworkX built-in community detection")
|
||||
|
||||
# Import from our modules
|
||||
from utils import logger, log_structured
|
||||
from config import NEO4J_URL, NEO4J_USERNAME, NEO4J_PASSWORD
|
||||
import config
|
||||
|
||||
# Define the GraphRAGExtractor class
|
||||
class GraphRAGExtractor(TransformComponent):
|
||||
"""Extract triples from a graph.
|
||||
|
||||
Uses an LLM and a simple prompt + output parsing to
|
||||
extract paths (i.e. triples) and entity, relation descriptions
|
||||
from text.
|
||||
|
||||
Args:
|
||||
llm (LLM):
|
||||
The language model to use.
|
||||
extract_prompt (Union[str, PromptTemplate]):
|
||||
The prompt to use for extracting triples.
|
||||
parse_fn (callable):
|
||||
A function to parse the output of the language
|
||||
model.
|
||||
num_workers (int):
|
||||
The number of workers to use for parallel
|
||||
processing.
|
||||
max_paths_per_chunk (int):
|
||||
The maximum number of paths to extract per chunk.
|
||||
"""
|
||||
|
||||
llm: LLM
|
||||
extract_prompt: PromptTemplate
|
||||
parse_fn: Callable
|
||||
num_workers: int
|
||||
max_paths_per_chunk: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: Optional[LLM] = None,
|
||||
extract_prompt: Optional[Union[str, PromptTemplate]] = None,
|
||||
parse_fn: Callable = default_parse_triplets_fn,
|
||||
max_paths_per_chunk: int = 10,
|
||||
num_workers: int = 8,
|
||||
) -> None:
|
||||
"""Init params."""
|
||||
from llama_index.core import Settings
|
||||
|
||||
if isinstance(extract_prompt, str):
|
||||
extract_prompt = PromptTemplate(extract_prompt)
|
||||
|
||||
super().__init__(
|
||||
llm=llm or Settings.llm,
|
||||
extract_prompt=extract_prompt or DEFAULT_KG_TRIPLET_EXTRACT_PROMPT,
|
||||
parse_fn=parse_fn,
|
||||
num_workers=num_workers,
|
||||
max_paths_per_chunk=max_paths_per_chunk,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "GraphExtractor"
|
||||
|
||||
def __call__(
|
||||
self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
|
||||
) -> List[BaseNode]:
|
||||
"""Extract triples from nodes."""
|
||||
return asyncio.run(
|
||||
self.acall(nodes, show_progress=show_progress, **kwargs)
|
||||
)
|
||||
|
||||
async def _aextract(self, node: BaseNode) -> BaseNode:
|
||||
"""Extract triples from a node."""
|
||||
assert hasattr(node, "text")
|
||||
|
||||
text = node.get_content(metadata_mode="llm")
|
||||
try:
|
||||
llm_response = await self.llm.apredict(
|
||||
self.extract_prompt,
|
||||
text=text,
|
||||
max_knowledge_triplets=self.max_paths_per_chunk,
|
||||
)
|
||||
entities, entities_relationship = self.parse_fn(llm_response)
|
||||
except ValueError:
|
||||
entities = []
|
||||
entities_relationship = []
|
||||
|
||||
existing_nodes = node.metadata.pop(KG_NODES_KEY, [])
|
||||
existing_relations = node.metadata.pop(KG_RELATIONS_KEY, [])
|
||||
|
||||
entity_metadata = node.metadata.copy()
|
||||
for entity, entity_type, description in entities:
|
||||
entity_metadata["entity_description"] = description
|
||||
entity_node = EntityNode(
|
||||
name=entity, label=entity_type,
|
||||
properties=entity_metadata
|
||||
)
|
||||
existing_nodes.append(entity_node)
|
||||
|
||||
relation_metadata = node.metadata.copy()
|
||||
for triple in entities_relationship:
|
||||
subj, obj, rel, description = triple
|
||||
relation_metadata["relationship_description"] = description
|
||||
rel_node = Relation(
|
||||
label=rel,
|
||||
source_id=subj,
|
||||
target_id=obj,
|
||||
properties=relation_metadata,
|
||||
)
|
||||
existing_relations.append(rel_node)
|
||||
|
||||
node.metadata[KG_NODES_KEY] = existing_nodes
|
||||
node.metadata[KG_RELATIONS_KEY] = existing_relations
|
||||
return node
|
||||
|
||||
async def acall(
|
||||
self, nodes: List[BaseNode], show_progress: bool = False, **kwargs: Any
|
||||
) -> List[BaseNode]:
|
||||
"""Extract triples from nodes async."""
|
||||
jobs = []
|
||||
for node in nodes:
|
||||
jobs.append(self._aextract(node))
|
||||
|
||||
return await run_jobs(
|
||||
jobs,
|
||||
workers=self.num_workers,
|
||||
show_progress=show_progress,
|
||||
desc="Extracting paths from text",
|
||||
)
|
||||
|
||||
# Define the GraphRAGStore class (integrating with Neo4j)
|
||||
import pickle
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
class GraphRAGStore:
|
||||
community_summary = {}
|
||||
entity_info = None
|
||||
max_cluster_size = 5
|
||||
property_graph_store = None
|
||||
communities_built = False # Track if communities have been built
|
||||
|
||||
# Path for cached community data
|
||||
CACHE_DIR = Path("index_storage/graphrag_cache")
|
||||
COMMUNITY_CACHE_FILE = CACHE_DIR / "community_summary.pickle"
|
||||
ENTITY_INFO_CACHE_FILE = CACHE_DIR / "entity_info.pickle"
|
||||
|
||||
def __init__(self, property_graph_store):
|
||||
"""Initialize with a property_graph_store (Neo4j or in-memory)."""
|
||||
self.property_graph_store = property_graph_store
|
||||
self.community_summary = {}
|
||||
self.entity_info = None
|
||||
self.communities_built = False
|
||||
|
||||
# Ensure cache directory exists
|
||||
os.makedirs(self.CACHE_DIR, exist_ok=True)
|
||||
|
||||
def add_nodes(self, nodes):
|
||||
"""Add nodes to the property graph store."""
|
||||
return self.property_graph_store.add_nodes(nodes)
|
||||
|
||||
def add_relationships(self, relationships):
|
||||
"""Add relationships to the property graph store."""
|
||||
return self.property_graph_store.add_relationships(relationships)
|
||||
|
||||
def get_triplets(self):
|
||||
"""Get triplets from the property graph store."""
|
||||
return self.property_graph_store.get_triplets()
|
||||
|
||||
def generate_community_summary(self, text):
|
||||
"""Generate summary for a given text using an LLM with handling for large contexts."""
|
||||
|
||||
# Check if text is too long and chunk if needed
|
||||
if len(text) > 30000: # Approximate character limit
|
||||
log_structured('info', f'Community text is large ({len(text)} chars). Chunking for summarization.')
|
||||
# Split into smaller chunks (simple approach)
|
||||
chunks = [text[i:i+30000] for i in range(0, len(text), 30000)]
|
||||
summaries = []
|
||||
|
||||
for i, chunk in enumerate(chunks):
|
||||
try:
|
||||
# Use GPT-4o-mini model for better cost efficiency
|
||||
llm = OpenAI(model="gpt-4o-mini")
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role="system",
|
||||
content="Summarize these knowledge graph relationships concisely."
|
||||
),
|
||||
ChatMessage(role="user", content=chunk),
|
||||
]
|
||||
response = llm.chat(messages)
|
||||
summaries.append(str(response).strip())
|
||||
log_structured('info', f'Successfully summarized community chunk {i+1}/{len(chunks)}')
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error summarizing community chunk {i+1}/{len(chunks)}: {e}')
|
||||
|
||||
# Then summarize the summaries
|
||||
if summaries:
|
||||
final_summary_text = "\n\n".join(summaries)
|
||||
try:
|
||||
llm = OpenAI(model="gpt-4o-mini")
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role="system",
|
||||
content="Create a coherent summary from these partial summaries."
|
||||
),
|
||||
ChatMessage(role="user", content=final_summary_text),
|
||||
]
|
||||
response = llm.chat(messages)
|
||||
return str(response).strip()
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error creating final summary from chunks: {e}')
|
||||
# Return the concatenated summaries if we can't summarize them
|
||||
return final_summary_text
|
||||
else:
|
||||
return "Unable to generate community summary due to size limitations."
|
||||
|
||||
# For normal size text, use the larger model directly
|
||||
try:
|
||||
# Use GPT-4o-mini model for better cost efficiency
|
||||
llm = OpenAI(model="gpt-4o-mini")
|
||||
messages = [
|
||||
ChatMessage(
|
||||
role="system",
|
||||
content=(
|
||||
"You are provided with a set of "
|
||||
"relationships from a knowledge graph, each represented as "
|
||||
"entity1->entity2->relation-"
|
||||
">relationship_description. Your task is to create a summary of "
|
||||
"these relationships. The summary should include "
|
||||
"the names of the entities involved and a concise synthesis "
|
||||
"of the relationship descriptions. The "
|
||||
"goal is to capture the most critical and relevant details that "
|
||||
"highlight the nature and significance of "
|
||||
"each relationship. Ensure that the summary is coherent and "
|
||||
"integrates the information in a way that "
|
||||
"emphasizes the key aspects of the relationships."
|
||||
),
|
||||
),
|
||||
ChatMessage(role="user", content=text),
|
||||
]
|
||||
response = llm.chat(messages)
|
||||
clean_response = re.sub(r"^assistant:\s*", "", str(response)).strip()
|
||||
return clean_response
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error generating community summary: {e}')
|
||||
return f"Error summarizing community: {str(e)}"
|
||||
|
||||
def build_communities(self):
|
||||
"""Builds communities from the graph and summarizes them."""
|
||||
# Skip if communities are already built in this session
|
||||
if self.communities_built:
|
||||
log_structured('info', 'Communities already built in this session, skipping rebuild')
|
||||
return
|
||||
|
||||
# First check if we can load from cache
|
||||
if self.load_from_cache():
|
||||
log_structured('info', 'Using cached community data instead of rebuilding')
|
||||
self.communities_built = True
|
||||
return
|
||||
|
||||
log_structured('info', 'Building communities from graph data')
|
||||
nx_graph = self._create_nx_graph()
|
||||
|
||||
# Use either Leiden algorithm (from graspologic) or an alternative
|
||||
try:
|
||||
from graspologic.partition import hierarchical_leiden
|
||||
community_hierarchical_clusters = hierarchical_leiden(
|
||||
nx_graph, max_cluster_size=self.max_cluster_size
|
||||
)
|
||||
self.entity_info, community_info = self._collect_community_info(
|
||||
nx_graph, community_hierarchical_clusters
|
||||
)
|
||||
except ImportError:
|
||||
# Fallback to community detection using NetworkX or python-louvain
|
||||
try:
|
||||
from community import best_partition
|
||||
partition = best_partition(nx_graph)
|
||||
# Reformat partition data to expected structure
|
||||
clusters = []
|
||||
for node, cluster_id in partition.items():
|
||||
class Cluster:
|
||||
def __init__(self, node, cluster):
|
||||
self.node = node
|
||||
self.cluster = cluster
|
||||
clusters.append(Cluster(node, cluster_id))
|
||||
self.entity_info, community_info = self._collect_community_info(
|
||||
nx_graph, clusters
|
||||
)
|
||||
except ImportError:
|
||||
# Use NetworkX's built-in community detection
|
||||
from networkx.algorithms import community
|
||||
communities = community.greedy_modularity_communities(nx_graph)
|
||||
clusters = []
|
||||
for i, comm in enumerate(communities):
|
||||
for node in comm:
|
||||
class Cluster:
|
||||
def __init__(self, node, cluster):
|
||||
self.node = node
|
||||
self.cluster = cluster
|
||||
clusters.append(Cluster(node, i))
|
||||
self.entity_info, community_info = self._collect_community_info(
|
||||
nx_graph, clusters
|
||||
)
|
||||
|
||||
self._summarize_communities(community_info)
|
||||
|
||||
# Cache the results after building
|
||||
self.save_to_cache()
|
||||
|
||||
# Mark communities as built for this session
|
||||
self.communities_built = True
|
||||
|
||||
def _create_nx_graph(self):
|
||||
"""Converts internal graph representation to NetworkX graph."""
|
||||
nx_graph = nx.Graph()
|
||||
triplets = self.get_triplets()
|
||||
for entity1, relation, entity2 in triplets:
|
||||
nx_graph.add_node(entity1.name)
|
||||
nx_graph.add_node(entity2.name)
|
||||
nx_graph.add_edge(
|
||||
relation.source_id,
|
||||
relation.target_id,
|
||||
relationship=relation.label,
|
||||
description=relation.properties.get("relationship_description", "No description provided"),
|
||||
)
|
||||
return nx_graph
|
||||
|
||||
def _collect_community_info(self, nx_graph, clusters):
|
||||
"""
|
||||
Collect information for each node based on their community,
|
||||
allowing entities to belong to multiple clusters.
|
||||
"""
|
||||
entity_info = defaultdict(set)
|
||||
community_info = defaultdict(list)
|
||||
|
||||
for item in clusters:
|
||||
node = item.node
|
||||
cluster_id = item.cluster
|
||||
|
||||
# Update entity_info
|
||||
entity_info[node].add(cluster_id)
|
||||
|
||||
for neighbor in nx_graph.neighbors(node):
|
||||
edge_data = nx_graph.get_edge_data(node, neighbor)
|
||||
if edge_data:
|
||||
detail = f"{node} -> {neighbor} -> {edge_data['relationship']} -> {edge_data['description']}"
|
||||
community_info[cluster_id].append(detail)
|
||||
|
||||
# Convert sets to lists for easier serialization if needed
|
||||
entity_info = {k: list(v) for k, v in entity_info.items()}
|
||||
|
||||
return dict(entity_info), dict(community_info)
|
||||
|
||||
def _summarize_communities(self, community_info):
|
||||
"""Generate and store summaries for each community."""
|
||||
for community_id, details in community_info.items():
|
||||
details_text = "\n".join(details) + "." # Ensure it ends with a period
|
||||
self.community_summary[community_id] = self.generate_community_summary(details_text)
|
||||
|
||||
def save_to_cache(self):
|
||||
"""Save community data to disk cache."""
|
||||
try:
|
||||
# Save community summary
|
||||
with open(self.COMMUNITY_CACHE_FILE, 'wb') as f:
|
||||
pickle.dump(self.community_summary, f)
|
||||
|
||||
# Save entity info
|
||||
with open(self.ENTITY_INFO_CACHE_FILE, 'wb') as f:
|
||||
pickle.dump(self.entity_info, f)
|
||||
|
||||
log_structured('info', 'Successfully cached GraphRAG community data',
|
||||
{'community_count': len(self.community_summary),
|
||||
'entity_count': len(self.entity_info) if self.entity_info else 0})
|
||||
return True
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error saving GraphRAG cache: {e}')
|
||||
return False
|
||||
|
||||
def load_from_cache(self):
|
||||
"""Load community data from disk cache if available."""
|
||||
if not self.COMMUNITY_CACHE_FILE.exists() or not self.ENTITY_INFO_CACHE_FILE.exists():
|
||||
log_structured('info', 'GraphRAG cache files not found, will build communities from scratch')
|
||||
return False
|
||||
|
||||
try:
|
||||
# Load community summary
|
||||
with open(self.COMMUNITY_CACHE_FILE, 'rb') as f:
|
||||
self.community_summary = pickle.load(f)
|
||||
|
||||
# Load entity info
|
||||
with open(self.ENTITY_INFO_CACHE_FILE, 'rb') as f:
|
||||
self.entity_info = pickle.load(f)
|
||||
|
||||
log_structured('info', 'Successfully loaded GraphRAG community data from cache',
|
||||
{'community_count': len(self.community_summary),
|
||||
'entity_count': len(self.entity_info) if self.entity_info else 0})
|
||||
|
||||
# Mark communities as built when successfully loaded from cache
|
||||
self.communities_built = True
|
||||
return True
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error loading GraphRAG cache: {e}')
|
||||
# Reset to empty in case of partial load
|
||||
self.community_summary = {}
|
||||
self.entity_info = None
|
||||
self.communities_built = False
|
||||
return False
|
||||
|
||||
def get_community_summaries(self):
|
||||
"""Returns the community summaries, building them if not already done."""
|
||||
if not self.community_summary:
|
||||
# Try to load from cache first
|
||||
if not self.load_from_cache():
|
||||
# If cache load fails, build from scratch
|
||||
self.build_communities()
|
||||
# Cache the results for next time
|
||||
self.save_to_cache()
|
||||
return self.community_summary
|
||||
|
||||
# Define the GraphRAGQueryEngine class
|
||||
from typing import Dict, Any
|
||||
|
||||
class GraphRAGQueryEngine:
|
||||
"""Query engine that combines vector retrieval with graph-based community retrieval."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vector_retriever: VectorIndexRetriever,
|
||||
graph_store: GraphRAGStore,
|
||||
llm: Optional[LLM] = None,
|
||||
similarity_top_k: int = 20
|
||||
):
|
||||
"""Initialize with both a vector retriever and graph store."""
|
||||
# Initialize all required fields
|
||||
self.vector_retriever = vector_retriever
|
||||
self.graph_store = graph_store
|
||||
self.llm = llm or Settings.llm
|
||||
self.similarity_top_k = similarity_top_k
|
||||
|
||||
# Check if communities are built, but don't try to build them here
|
||||
# since that might cause errors with large graphs
|
||||
if not hasattr(self.graph_store, 'entity_info') or self.graph_store.entity_info is None:
|
||||
log_structured('warning', 'GraphRAGQueryEngine initialized without community data. Vector retrieval will still work, but community retrieval may be limited.')
|
||||
|
||||
def custom_query(self, query_str: str) -> Dict:
|
||||
"""Process query using both vector retrieval and community-based approach."""
|
||||
log_structured('info', 'GraphRAG query engine: Starting dual retrieval', {'query': query_str})
|
||||
|
||||
# Step 1: Get vector search results
|
||||
vector_nodes = self.vector_retriever.retrieve(query_str)
|
||||
vector_context = "\n\n".join([node.node.get_content() for node in vector_nodes])
|
||||
log_structured('info', 'GraphRAG query engine: Vector retrieval complete',
|
||||
{'node_count': len(vector_nodes)})
|
||||
|
||||
# Step 2: Get GraphRAG community results (if communities exist)
|
||||
graphrag_context = ""
|
||||
community_ids = []
|
||||
|
||||
if hasattr(self.graph_store, 'entity_info') and self.graph_store.entity_info is not None:
|
||||
try:
|
||||
entities = self.get_entities(query_str, vector_nodes)
|
||||
community_ids = self.retrieve_entity_communities(self.graph_store.entity_info, entities)
|
||||
|
||||
try:
|
||||
community_summaries = self.graph_store.get_community_summaries()
|
||||
|
||||
if community_ids:
|
||||
filtered_summaries = {id: summary for id, summary in community_summaries.items()
|
||||
if id in community_ids}
|
||||
graphrag_context = "\n\n".join(filtered_summaries.values())
|
||||
log_structured('info', 'GraphRAG query engine: Community retrieval complete',
|
||||
{'community_count': len(filtered_summaries)})
|
||||
else:
|
||||
log_structured('info', 'GraphRAG query engine: No relevant communities found')
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error getting community summaries: {e}')
|
||||
# Continue without graph context
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error during community retrieval: {e}')
|
||||
# Continue with just vector context
|
||||
else:
|
||||
log_structured('warning', 'GraphRAG query engine: No community data available. Using only vector retrieval.')
|
||||
|
||||
# Step 3: Combine contexts and generate answer
|
||||
combined_result = {
|
||||
"vector_context": vector_context,
|
||||
"graphrag_context": graphrag_context,
|
||||
"vector_nodes": vector_nodes,
|
||||
"community_ids": community_ids
|
||||
}
|
||||
|
||||
return combined_result
|
||||
|
||||
def get_entities(self, query_str, vector_nodes):
|
||||
"""Extract entities from vector nodes that match the query."""
|
||||
entities = set()
|
||||
|
||||
# Extract entities from the retrieved nodes
|
||||
for node_with_score in vector_nodes:
|
||||
node = node_with_score.node
|
||||
if hasattr(node, 'metadata') and KG_NODES_KEY in node.metadata:
|
||||
for entity_node in node.metadata[KG_NODES_KEY]:
|
||||
if hasattr(entity_node, 'name'):
|
||||
entities.add(entity_node.name)
|
||||
|
||||
# If no entities were found in metadata, try extracting them from text
|
||||
if not entities:
|
||||
pattern = r"(?:^|\s)([A-Z][a-zA-Z0-9\s]+)(?:\s|$)"
|
||||
for node_with_score in vector_nodes:
|
||||
matches = re.findall(pattern, node_with_score.node.get_content())
|
||||
entities.update(matches)
|
||||
|
||||
log_structured('debug', 'GraphRAG query engine: Extracted entities',
|
||||
{'entities': list(entities), 'count': len(entities)})
|
||||
return list(entities)
|
||||
|
||||
def retrieve_entity_communities(self, entity_info, entities):
|
||||
"""
|
||||
Retrieve cluster information for given entities,
|
||||
allowing for multiple clusters per entity.
|
||||
|
||||
Args:
|
||||
entity_info (dict): Dictionary mapping entities to their cluster IDs (list).
|
||||
entities (list): List of entity names to retrieve information for.
|
||||
|
||||
Returns:
|
||||
List of community or cluster IDs to which an entity belongs.
|
||||
"""
|
||||
community_ids = []
|
||||
|
||||
for entity in entities:
|
||||
if entity in entity_info:
|
||||
community_ids.extend(entity_info[entity])
|
||||
else:
|
||||
# Try case-insensitive matching as fallback
|
||||
for stored_entity, clusters in entity_info.items():
|
||||
if stored_entity.lower() == entity.lower():
|
||||
community_ids.extend(clusters)
|
||||
break
|
||||
|
||||
return list(set(community_ids))
|
||||
|
||||
def custom_parse_fn(response_str: str) -> Any:
|
||||
"""Custom parser for LLM responses that extract entities and relationships"""
|
||||
json_pattern = r"\{.*\}"
|
||||
match = re.search(json_pattern, response_str, re.DOTALL)
|
||||
entities = []
|
||||
relationships = []
|
||||
|
||||
if not match:
|
||||
return entities, relationships
|
||||
|
||||
json_str = match.group(0)
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
entities = [
|
||||
(
|
||||
entity["entity_name"],
|
||||
entity["entity_type"],
|
||||
entity.get("entity_description", f"Description of {entity['entity_name']}"),
|
||||
)
|
||||
for entity in data.get("entities", [])
|
||||
]
|
||||
relationships = [
|
||||
(
|
||||
relation["source_entity"],
|
||||
relation["target_entity"],
|
||||
relation["relation"],
|
||||
relation.get("relationship_description", f"Relationship between {relation['source_entity']} and {relation['target_entity']}"),
|
||||
)
|
||||
for relation in data.get("relationships", [])
|
||||
]
|
||||
return entities, relationships
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
log_structured('error', f"Error parsing response: {e}", {'json_str': json_str[:200]})
|
||||
return entities, relationships
|
||||
|
||||
# Define the prompt template for triple extraction
|
||||
KG_TRIPLET_EXTRACT_TMPL = """
|
||||
-Goal-
|
||||
Given a text document, identify all entities and their entity types from the text and all relationships among the identified entities.
|
||||
|
||||
Given the text, extract up to {max_knowledge_triplets} entity-relation triplets.
|
||||
|
||||
-Steps-
|
||||
1. Identify all entities. For each identified entity, extract the following information:
|
||||
- entity_name: Name of the entity, capitalized
|
||||
- entity_type: Type of the entity
|
||||
- entity_description: Comprehensive description of the entity's attributes and activities
|
||||
|
||||
2. From the entities identified in step 1, identify all pairs of (source_entity, target_entity) that are *clearly related* to each other.
|
||||
For each pair of related entities, extract the following information:
|
||||
- source_entity: name of the source entity, as identified in step 1
|
||||
- target_entity: name of the target entity, as identified in step 1
|
||||
- relation: relationship between source_entity and target_entity
|
||||
- relationship_description: explanation as to why you think the source entity and the target entity are related to each other
|
||||
|
||||
3. Output Formatting:
|
||||
- Return the result in valid JSON format with two keys: 'entities' (list of entity objects) and 'relationships' (list of relationship objects).
|
||||
- Exclude any text outside the JSON structure (e.g., no explanations or comments).
|
||||
- If no entities or relationships are identified, return empty lists: { "entities": [], "relationships": [] }.
|
||||
|
||||
-Real Data-
|
||||
######################
|
||||
text: {text}
|
||||
######################
|
||||
output:
|
||||
"""
|
||||
|
||||
def create_graph_components(llm, nodes=None, max_paths_per_chunk=10, force_reindex=False):
|
||||
"""
|
||||
Create GraphRAG components for the HP RAG pipeline.
|
||||
|
||||
Args:
|
||||
llm: The LLM to use for graph extraction and querying
|
||||
nodes: List of nodes to create the graph from (only used if indexing is needed)
|
||||
max_paths_per_chunk: Maximum number of paths to extract per chunk
|
||||
force_reindex: If True, always recreate the index even if content exists
|
||||
|
||||
Returns:
|
||||
tuple: (graph_store, property_graph_index)
|
||||
"""
|
||||
log_structured('info', 'Creating GraphRAG components')
|
||||
|
||||
# Note: The graph_store object created here will automatically:
|
||||
# 1. Try to load community data from cache files when build_communities() is called
|
||||
# 2. Save to cache after building communities if loading failed
|
||||
|
||||
# Connect to Neo4j - hard error if not available
|
||||
property_graph_store = None
|
||||
try:
|
||||
log_structured('info', f'Connecting to Neo4j at {NEO4J_URL}')
|
||||
property_graph_store = Neo4jPropertyGraphStore(
|
||||
username=NEO4J_USERNAME,
|
||||
password=NEO4J_PASSWORD,
|
||||
url=NEO4J_URL
|
||||
)
|
||||
log_structured('info', 'Successfully connected to Neo4j database')
|
||||
except Exception as e:
|
||||
log_structured('critical', f'FATAL ERROR: Cannot connect to Neo4j: {e}')
|
||||
raise RuntimeError(f"Neo4j connection failed. This application requires Neo4j to be running. Error: {e}")
|
||||
|
||||
# Create GraphRAGStore wrapper
|
||||
graph_store = GraphRAGStore(property_graph_store)
|
||||
|
||||
# Check if Neo4j already has content
|
||||
triplets = graph_store.get_triplets()
|
||||
has_existing_content = len(triplets) > 0
|
||||
|
||||
log_structured('info', f'Neo4j check: Found {len(triplets)} triplets')
|
||||
|
||||
if has_existing_content and not force_reindex:
|
||||
log_structured('info', f'Neo4j already contains {len(triplets)} triplets. Skipping indexing.')
|
||||
|
||||
# Create a minimal PropertyGraphIndex without indexing
|
||||
property_graph_index = PropertyGraphIndex(
|
||||
nodes=[], # Empty nodes since we're not indexing
|
||||
property_graph_store=property_graph_store,
|
||||
)
|
||||
|
||||
# Build communities from existing data (if not already built)
|
||||
if not graph_store.communities_built:
|
||||
log_structured('info', 'Building graph communities from existing Neo4j data')
|
||||
try:
|
||||
graph_store.build_communities()
|
||||
log_structured('info', 'Communities built successfully')
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error building communities: {e}')
|
||||
else:
|
||||
log_structured('info', 'Communities already built, skipping rebuild')
|
||||
else:
|
||||
# Need to perform indexing
|
||||
if not nodes:
|
||||
log_structured('error', 'No nodes provided for indexing and Neo4j is empty or force_reindex=True')
|
||||
raise ValueError("Nodes must be provided for indexing when Neo4j is empty or force_reindex=True")
|
||||
|
||||
# Create the knowledge graph extractor
|
||||
kg_extractor = GraphRAGExtractor(
|
||||
llm=llm,
|
||||
extract_prompt=KG_TRIPLET_EXTRACT_TMPL,
|
||||
max_paths_per_chunk=max_paths_per_chunk,
|
||||
parse_fn=custom_parse_fn,
|
||||
)
|
||||
|
||||
if has_existing_content and force_reindex:
|
||||
log_structured('info', 'Force reindexing requested. Clearing existing Neo4j data.')
|
||||
try:
|
||||
# Try to clear the graph using Neo4j's native query
|
||||
# Note: This requires the Neo4j APOC plugin to be installed
|
||||
from neo4j import GraphDatabase
|
||||
driver = GraphDatabase.driver(NEO4J_URL, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))
|
||||
with driver.session() as session:
|
||||
session.run("MATCH (n) DETACH DELETE n")
|
||||
driver.close()
|
||||
log_structured('info', 'Successfully cleared Neo4j database')
|
||||
except Exception as e:
|
||||
log_structured('warning', f'Error clearing Neo4j database: {e}. Proceeding with indexing anyway.')
|
||||
|
||||
# Build the property graph index
|
||||
log_structured('info', 'Building PropertyGraphIndex', {'node_count': len(nodes)})
|
||||
property_graph_index = PropertyGraphIndex(
|
||||
nodes=nodes,
|
||||
kg_extractors=[kg_extractor],
|
||||
property_graph_store=property_graph_store,
|
||||
show_progress=True,
|
||||
)
|
||||
|
||||
# Build communities
|
||||
log_structured('info', 'Building graph communities')
|
||||
try:
|
||||
graph_store.build_communities()
|
||||
log_structured('info', 'Communities built successfully')
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error building communities: {e}')
|
||||
|
||||
return graph_store, property_graph_index
|
||||
|
||||
def create_graphrag_query_engine(vector_retriever, graph_store, llm, similarity_top_k=20):
|
||||
"""
|
||||
Create GraphRAG query engine that combines vector and graph-based retrieval.
|
||||
|
||||
Args:
|
||||
vector_retriever: VectorIndexRetriever for standard retrieval
|
||||
graph_store: GraphRAGStore for community-based retrieval
|
||||
llm: LLM for generating answer
|
||||
similarity_top_k: Number of top results to retrieve
|
||||
|
||||
Returns:
|
||||
GraphRAGQueryEngine: Query engine for hybrid retrieval
|
||||
"""
|
||||
from utils import log_structured
|
||||
|
||||
try:
|
||||
# Explicitly validate inputs before passing to constructor
|
||||
if vector_retriever is None:
|
||||
raise ValueError("vector_retriever cannot be None")
|
||||
if graph_store is None:
|
||||
raise ValueError("graph_store cannot be None")
|
||||
if llm is None:
|
||||
raise ValueError("llm cannot be None")
|
||||
|
||||
# Log for debugging
|
||||
log_structured('debug', 'Creating GraphRAGQueryEngine with parameters', {
|
||||
'vector_retriever_type': type(vector_retriever).__name__,
|
||||
'graph_store_type': type(graph_store).__name__,
|
||||
'llm_type': type(llm).__name__,
|
||||
'similarity_top_k': similarity_top_k
|
||||
})
|
||||
|
||||
# Create the engine
|
||||
return GraphRAGQueryEngine(
|
||||
vector_retriever=vector_retriever,
|
||||
graph_store=graph_store,
|
||||
llm=llm,
|
||||
similarity_top_k=similarity_top_k,
|
||||
)
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error in create_graphrag_query_engine: {e}')
|
||||
raise # Re-raise the exception for proper handling
|
||||
|
||||
def generate_final_answer(query, retrieval_result, llm):
|
||||
"""
|
||||
Generate a final answer using both vector and graph-based context.
|
||||
|
||||
Args:
|
||||
query: The user's query
|
||||
retrieval_result: Result from GraphRAGQueryEngine with vector and graph contexts
|
||||
llm: LLM for generating the final response
|
||||
|
||||
Returns:
|
||||
str: The final answer
|
||||
"""
|
||||
vector_context = retrieval_result.get("vector_context", "")
|
||||
graphrag_context = retrieval_result.get("graphrag_context", "")
|
||||
|
||||
# Log the contexts for debugging (truncated for brevity)
|
||||
log_structured('debug', 'Generating final answer with dual context', {
|
||||
'query': query,
|
||||
'vector_context_length': len(vector_context),
|
||||
'graphrag_context_length': len(graphrag_context)
|
||||
})
|
||||
|
||||
if not vector_context and not graphrag_context:
|
||||
return "I couldn't find any relevant information to answer your question."
|
||||
|
||||
# If no model was provided or we're forcing to use a specific model
|
||||
if llm is None or not hasattr(llm, 'chat'):
|
||||
# Fallback to gpt-4o-mini for better cost efficiency
|
||||
llm = OpenAI(model="gpt-4o-mini")
|
||||
log_structured('info', 'Using gpt-4o-mini model for final answer generation')
|
||||
|
||||
prompt = f"""
|
||||
Based on the following information from two different sources, please answer this question: {query}
|
||||
|
||||
SOURCE 1 - VECTOR RETRIEVAL:
|
||||
{vector_context}
|
||||
|
||||
SOURCE 2 - KNOWLEDGE GRAPH COMMUNITIES:
|
||||
{graphrag_context}
|
||||
|
||||
Answer the question based on all the provided information. If there are differences between the sources,
|
||||
try to reconcile them or note the discrepancy. Please be concise and direct.
|
||||
"""
|
||||
|
||||
messages = [
|
||||
ChatMessage(role="system", content=prompt),
|
||||
ChatMessage(role="user", content="Please provide a comprehensive answer based on all the information provided.")
|
||||
]
|
||||
|
||||
response = llm.chat(messages)
|
||||
|
||||
# Extract just the message content, not the entire response object
|
||||
if hasattr(response, 'message') and hasattr(response.message, 'content'):
|
||||
content = response.message.content
|
||||
elif hasattr(response, 'content'):
|
||||
content = response.content
|
||||
else:
|
||||
# Fallback: convert to string but clean it
|
||||
content = str(response)
|
||||
|
||||
# Clean any remaining thinking patterns from the response
|
||||
import re
|
||||
thinking_patterns = [
|
||||
r'(?i)Thought:.*?Action:.*?Action Input:.*', # Remove the specific pattern
|
||||
r'(?i)^Thought:.*', # Remove any line starting with "Thought:"
|
||||
r'(?i)Action:.*?Action Input:.*', # Remove Action/Action Input patterns
|
||||
r'(?i)^(Thought|Action|Observation):.*', # Remove ReAct patterns
|
||||
]
|
||||
|
||||
for pattern in thinking_patterns:
|
||||
content = re.sub(pattern, '', content, flags=re.DOTALL | re.MULTILINE)
|
||||
|
||||
# Clean up extra whitespace
|
||||
content = re.sub(r'\n{3,}', '\n\n', content)
|
||||
content = content.strip()
|
||||
|
||||
# Final safety check
|
||||
if not content or re.search(r'(?i)^(Thought|Action|Observation):', content):
|
||||
log_structured('warning', 'GraphRAG final answer still contains thinking patterns, using fallback')
|
||||
content = "I found relevant information in the HP marketing materials that can help answer your question. Please let me know if you need more specific details."
|
||||
|
||||
return content
|
||||
127
init_mongodb.py
Normal file
127
init_mongodb.py
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
"""
|
||||
MongoDB Initialization Script for HP Chatbot
|
||||
|
||||
This script initializes the MongoDB database with the necessary collections for the HP chatbot.
|
||||
It creates collections for users, conversations, and messages.
|
||||
|
||||
Usage:
|
||||
python init_mongodb.py
|
||||
"""
|
||||
|
||||
import pymongo
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('mongodb_init.log')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# MongoDB connection information
|
||||
MONGO_URI = "mongodb://hp:hp@localhost:27017/?authSource=hp_chatbot" # HP user with hp_chatbot auth source
|
||||
DB_NAME = "hp_chatbot"
|
||||
|
||||
# Collection names
|
||||
USERS_COLLECTION = "users"
|
||||
CONVERSATIONS_COLLECTION = "conversations"
|
||||
MESSAGES_COLLECTION = "messages"
|
||||
|
||||
def init_mongodb():
|
||||
"""Initialize MongoDB database and collections."""
|
||||
try:
|
||||
# Connect to MongoDB
|
||||
logger.info("Connecting to MongoDB...")
|
||||
client = pymongo.MongoClient(MONGO_URI)
|
||||
|
||||
# Test connection
|
||||
client.admin.command('ping')
|
||||
logger.info("Successfully connected to MongoDB")
|
||||
|
||||
# Create or access database
|
||||
db = client[DB_NAME]
|
||||
logger.info(f"Using database: {DB_NAME}")
|
||||
|
||||
# Create collections if they don't exist
|
||||
if USERS_COLLECTION not in db.list_collection_names():
|
||||
db.create_collection(USERS_COLLECTION)
|
||||
logger.info(f"Created collection: {USERS_COLLECTION}")
|
||||
|
||||
# Create indexes for users collection
|
||||
db[USERS_COLLECTION].create_index("username", unique=True)
|
||||
# Create a unique sparse index for email - only enforces uniqueness when email exists
|
||||
db[USERS_COLLECTION].create_index("email", unique=True, sparse=True)
|
||||
logger.info("Created indexes for users collection")
|
||||
|
||||
if CONVERSATIONS_COLLECTION not in db.list_collection_names():
|
||||
db.create_collection(CONVERSATIONS_COLLECTION)
|
||||
logger.info(f"Created collection: {CONVERSATIONS_COLLECTION}")
|
||||
|
||||
# Create indexes for conversations collection
|
||||
db[CONVERSATIONS_COLLECTION].create_index("user_id")
|
||||
db[CONVERSATIONS_COLLECTION].create_index("session_id", unique=True)
|
||||
db[CONVERSATIONS_COLLECTION].create_index("created_at")
|
||||
db[CONVERSATIONS_COLLECTION].create_index("last_updated")
|
||||
logger.info("Created indexes for conversations collection")
|
||||
|
||||
if MESSAGES_COLLECTION not in db.list_collection_names():
|
||||
db.create_collection(MESSAGES_COLLECTION)
|
||||
logger.info(f"Created collection: {MESSAGES_COLLECTION}")
|
||||
|
||||
# Create indexes for messages collection
|
||||
db[MESSAGES_COLLECTION].create_index("conversation_id")
|
||||
db[MESSAGES_COLLECTION].create_index("timestamp")
|
||||
logger.info("Created indexes for messages collection")
|
||||
|
||||
logger.info("MongoDB initialization completed successfully")
|
||||
return True
|
||||
|
||||
except pymongo.errors.ConnectionFailure as e:
|
||||
logger.error(f"Could not connect to MongoDB: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"An error occurred during MongoDB initialization: {e}")
|
||||
return False
|
||||
|
||||
def display_collection_info(client):
|
||||
"""Display information about the collections in the database."""
|
||||
db = client[DB_NAME]
|
||||
|
||||
logger.info("=== Database Structure ===")
|
||||
for collection_name in db.list_collection_names():
|
||||
count = db[collection_name].count_documents({})
|
||||
logger.info(f"Collection: {collection_name}, Documents: {count}")
|
||||
|
||||
# Display indexes
|
||||
indexes = db[collection_name].index_information()
|
||||
logger.info(f" Indexes: {list(indexes.keys())}")
|
||||
|
||||
if __name__ == "__main__":
|
||||
if init_mongodb():
|
||||
# Display collection information
|
||||
client = pymongo.MongoClient(MONGO_URI)
|
||||
display_collection_info(client)
|
||||
|
||||
# Add sample user if none exist (optional)
|
||||
db = client[DB_NAME]
|
||||
if db[USERS_COLLECTION].count_documents({}) == 0:
|
||||
sample_user = {
|
||||
"username": "sample_user",
|
||||
"email": "sample@example.com",
|
||||
"created_at": datetime.utcnow(),
|
||||
"last_login": datetime.utcnow()
|
||||
}
|
||||
db[USERS_COLLECTION].insert_one(sample_user)
|
||||
logger.info("Added sample user for testing")
|
||||
|
||||
logger.info("Initialization complete. The database is ready for use.")
|
||||
else:
|
||||
logger.error("Failed to initialize MongoDB. See logs for details.")
|
||||
sys.exit(1)
|
||||
133
json_utils.py
Normal file
133
json_utils.py
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
# hp_chatbot/json_utils.py
|
||||
import json
|
||||
import llama_index
|
||||
from llama_index.core.tools import ToolOutput
|
||||
from llama_index.core.agent.react.types import (
|
||||
ActionReasoningStep,
|
||||
ObservationReasoningStep,
|
||||
ResponseReasoningStep,
|
||||
BaseReasoningStep,
|
||||
)
|
||||
from llama_index.core.llms import ChatMessage, LLM, ChatResponse as LlamaResponse
|
||||
from llama_index.core.base.response.schema import Response
|
||||
from flask.json.provider import JSONProvider
|
||||
from bson import ObjectId # Import ObjectId if used in responses/data
|
||||
from datetime import datetime
|
||||
|
||||
class CustomJSONEncoder(json.JSONEncoder):
|
||||
"""
|
||||
Custom JSON Encoder to handle LlamaIndex objects, BSON ObjectId, and other types.
|
||||
"""
|
||||
def default(self, obj):
|
||||
try:
|
||||
# Specific LlamaIndex Types
|
||||
if isinstance(obj, ToolOutput):
|
||||
return {
|
||||
'content': str(obj.content) if obj.content is not None else "",
|
||||
'tool_name': getattr(obj, 'tool_name', None),
|
||||
'raw_output': str(getattr(obj, 'raw_output', None)), # Safely convert raw_output
|
||||
'type': 'tool_output',
|
||||
'metadata': getattr(obj, 'metadata', {})
|
||||
}
|
||||
elif isinstance(obj, (llama_index.core.llms.ChatMessage, ChatMessage)):
|
||||
return {
|
||||
'role': str(obj.role),
|
||||
'content': str(obj.content),
|
||||
'additional_kwargs': obj.additional_kwargs if hasattr(obj, 'additional_kwargs') else {}
|
||||
}
|
||||
elif isinstance(obj, (LlamaResponse, Response)):
|
||||
return {
|
||||
'content': str(getattr(obj, 'response', getattr(obj, 'message', None))),
|
||||
'metadata': getattr(obj, 'metadata', {}),
|
||||
'type': 'llm_response'
|
||||
}
|
||||
elif isinstance(obj, ActionReasoningStep):
|
||||
return {
|
||||
'type': 'action_step',
|
||||
'action': obj.action,
|
||||
'action_input': obj.action_input, # Should be serializable dict
|
||||
'thought': getattr(obj, 'thought', None)
|
||||
}
|
||||
elif isinstance(obj, ObservationReasoningStep):
|
||||
return {
|
||||
'type': 'observation_step',
|
||||
'observation': str(obj.observation), # Ensure observation is string
|
||||
'thought': getattr(obj, 'thought', None)
|
||||
}
|
||||
elif isinstance(obj, ResponseReasoningStep):
|
||||
return {
|
||||
'type': 'response_step',
|
||||
'response': str(obj.response), # Ensure response is string
|
||||
'is_streaming': getattr(obj, 'is_streaming', False),
|
||||
'thought': getattr(obj, 'thought', None)
|
||||
}
|
||||
elif isinstance(obj, BaseReasoningStep): # Catch-all for other steps
|
||||
return {
|
||||
'type': 'base_reasoning_step',
|
||||
'thought': getattr(obj, 'thought', None),
|
||||
'is_done': getattr(obj, 'is_done', False),
|
||||
}
|
||||
# Handle LlamaIndex Document/Node related objects if needed
|
||||
elif isinstance(obj, llama_index.core.schema.Document):
|
||||
return {
|
||||
'doc_id': obj.id_,
|
||||
'text_preview': obj.text[:100] + "..." if obj.text else "",
|
||||
'metadata': obj.metadata, # Metadata should be serializable
|
||||
'type': 'llama_document'
|
||||
}
|
||||
elif isinstance(obj, llama_index.core.schema.NodeWithScore):
|
||||
return {
|
||||
'node': self.default(obj.node), # Recursively serialize the node
|
||||
'score': obj.score,
|
||||
'type': 'node_with_score'
|
||||
}
|
||||
elif isinstance(obj, llama_index.core.schema.TextNode):
|
||||
return {
|
||||
'node_id': obj.id_,
|
||||
'text_preview': obj.text[:100] + "..." if obj.text else "",
|
||||
'metadata': obj.metadata,
|
||||
'type': 'text_node'
|
||||
}
|
||||
|
||||
# Common Python Types
|
||||
elif isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, ObjectId):
|
||||
return str(obj)
|
||||
elif isinstance(obj, bytes):
|
||||
return "<bytes>" # Or encode to base64 if needed
|
||||
|
||||
# General Fallback for objects with __dict__
|
||||
elif hasattr(obj, '__dict__'):
|
||||
# Filter out private/callable attributes, be cautious with recursion
|
||||
try:
|
||||
d = {k: v for k, v in obj.__dict__.items()
|
||||
if not k.startswith('_') and not callable(v)}
|
||||
# Basic check to prevent deep recursion errors
|
||||
if len(d) > 50: # Arbitrary limit
|
||||
return f"<Complex object type {type(obj).__name__} with keys: {list(d.keys())[:5]}>"
|
||||
return d
|
||||
except Exception:
|
||||
return f"<Unserializable object type {type(obj).__name__} with __dict__>"
|
||||
|
||||
# Final fallback using standard JSON encoding
|
||||
return super().default(obj)
|
||||
|
||||
except Exception as e:
|
||||
# Log the error? Be careful about logging sensitive data
|
||||
# print(f"DEBUG: JSON encoding error for type {type(obj).__name__}: {e}")
|
||||
return f"<Unserializable object of type {type(obj).__name__}>"
|
||||
|
||||
|
||||
class CustomJSONProvider(JSONProvider):
|
||||
"""
|
||||
Flask JSON Provider using the CustomJSONEncoder.
|
||||
"""
|
||||
def dumps(self, obj, **kwargs):
|
||||
kwargs.setdefault('cls', CustomJSONEncoder)
|
||||
kwargs.setdefault('ensure_ascii', False) # Often useful for non-English text
|
||||
kwargs.setdefault('indent', None) # No indent for production APIs
|
||||
return json.dumps(obj, **kwargs)
|
||||
|
||||
def loads(self, s, **kwargs):
|
||||
return json.loads(s, **kwargs)
|
||||
168
main.py
Normal file
168
main.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
# hp_chatbot/main.py
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
from flask import Flask
|
||||
from flask_cors import CORS
|
||||
|
||||
# Ensure the project directory is in the Python path
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
if current_dir not in sys.path:
|
||||
sys.path.insert(0, current_dir)
|
||||
|
||||
# Import necessary components from our modules
|
||||
from config import (
|
||||
APPLICATION_ROOT, MAX_CONTENT_LENGTH,
|
||||
CORS_ALLOWED_ORIGINS, CORS_SUPPORTS_CREDENTIALS,
|
||||
SERVER_HOST, SERVER_PORT, USE_RELOADER, LOG_LEVEL,
|
||||
KEEP_ALIVE_TIMEOUT, READ_TIMEOUT, WRITE_TIMEOUT
|
||||
)
|
||||
from utils import logger, log_structured
|
||||
from json_utils import CustomJSONProvider
|
||||
from ai_core import initialize_global_index # Import the initialization function
|
||||
from shared_state import global_workflow_agent, is_agent_available # Import shared state
|
||||
from routes import register_routes
|
||||
from init_mongodb import init_mongodb # Your MongoDB initialization script
|
||||
|
||||
# --- Flask App Initialization ---
|
||||
app = Flask(__name__)
|
||||
|
||||
# Apply custom JSON provider for handling special types (LlamaIndex objects, etc.)
|
||||
app.json_provider_class = CustomJSONProvider
|
||||
app.json = CustomJSONProvider(app)
|
||||
|
||||
# Configuration
|
||||
app.config['MAX_CONTENT_LENGTH'] = MAX_CONTENT_LENGTH
|
||||
if APPLICATION_ROOT:
|
||||
app.config['APPLICATION_ROOT'] = APPLICATION_ROOT
|
||||
# If using APPLICATION_ROOT, you might need to adjust route prefixes
|
||||
# or use a Blueprint with url_prefix=APPLICATION_ROOT
|
||||
log_structured('info', f"Flask Application Root set to: {APPLICATION_ROOT}")
|
||||
|
||||
# CORS Configuration
|
||||
CORS(app,
|
||||
resources={r"/*": {"origins": CORS_ALLOWED_ORIGINS}},
|
||||
supports_credentials=CORS_SUPPORTS_CREDENTIALS,
|
||||
# Expose custom headers if needed by the frontend
|
||||
# expose_headers=["Content-Disposition"] # Example for downloads
|
||||
)
|
||||
log_structured('info', f"CORS configured for origins: {CORS_ALLOWED_ORIGINS}")
|
||||
|
||||
|
||||
# --- Register Routes ---
|
||||
# Pass the app object to the function in routes.py
|
||||
register_routes(app)
|
||||
log_structured('info', "Flask routes registered.")
|
||||
|
||||
|
||||
# --- Startup Function ---
|
||||
async def startup_event() -> bool:
|
||||
"""Tasks to run when the application starts.
|
||||
|
||||
Returns:
|
||||
bool: True if all startup tasks completed successfully, False otherwise
|
||||
"""
|
||||
log_structured('info', "Application startup sequence initiated.")
|
||||
all_success = True
|
||||
|
||||
# 1. Initialize MongoDB Connection & Schema (using your script)
|
||||
log_structured('info', "Initializing MongoDB connection...")
|
||||
mongo_success = False
|
||||
try:
|
||||
if init_mongodb():
|
||||
log_structured('info', "MongoDB initialized successfully.")
|
||||
mongo_success = True
|
||||
else:
|
||||
log_structured('warning', "MongoDB initialization script finished, but reported issues.")
|
||||
all_success = False
|
||||
except Exception as db_err:
|
||||
log_structured('critical', "FATAL: MongoDB initialization failed.", {'error': str(db_err)})
|
||||
all_success = False
|
||||
# We'll continue in a degraded state
|
||||
|
||||
# 2. Initialize Global AI Index and Agent
|
||||
log_structured('info', "Initializing global AI index and agent...")
|
||||
index_success = await initialize_global_index()
|
||||
|
||||
# Explicitly check the status after initialization
|
||||
if not is_agent_available():
|
||||
log_structured('critical', "After initialize_global_index, global_workflow_agent is still unavailable, even though function may have reported success")
|
||||
all_success = False
|
||||
elif not index_success:
|
||||
log_structured('warning', "AI initialization reported failure, but will continue in degraded state")
|
||||
all_success = False
|
||||
else:
|
||||
log_structured('info', "AI initialization successful, global_workflow_agent is available")
|
||||
|
||||
log_structured('info', f"Application startup sequence complete. Overall success: {all_success}")
|
||||
return all_success
|
||||
|
||||
|
||||
# --- Shutdown Function (Optional) ---
|
||||
async def shutdown_event():
|
||||
"""Tasks to run when the application stops."""
|
||||
log_structured('info', "Application shutdown sequence initiated.")
|
||||
# Add any cleanup tasks here (e.g., closing connections if not handled elsewhere)
|
||||
# Note: Hypercorn might not always guarantee graceful shutdown execution.
|
||||
log_structured('info', "Application shutdown sequence complete.")
|
||||
|
||||
|
||||
# --- Main Execution Block ---
|
||||
if __name__ == '__main__':
|
||||
from hypercorn.config import Config as HypercornConfig
|
||||
from hypercorn.asyncio import serve as hypercorn_serve
|
||||
|
||||
# Create Hypercorn config object
|
||||
config = HypercornConfig()
|
||||
|
||||
# Basic settings
|
||||
config.bind = [f"{SERVER_HOST}:{SERVER_PORT}"]
|
||||
config.use_reloader = USE_RELOADER
|
||||
config.accesslog = '-' # Log to stdout/stderr
|
||||
config.errorlog = '-' # Log to stdout/stderr
|
||||
config.loglevel = LOG_LEVEL.upper()
|
||||
config.worker_class = 'asyncio'
|
||||
|
||||
# Timeouts (ensure these are floats or ints)
|
||||
config.keep_alive_timeout = float(KEEP_ALIVE_TIMEOUT)
|
||||
config.read_timeout = float(READ_TIMEOUT)
|
||||
config.write_timeout = float(WRITE_TIMEOUT)
|
||||
|
||||
# Request size limits (check Hypercorn docs for exact names, might vary slightly)
|
||||
# These might apply to HTTP/1.1 or HTTP/2 differently.
|
||||
# config.h11_max_incomplete_size = MAX_CONTENT_LENGTH # Example for HTTP/1.1
|
||||
# config.h2_max_concurrent_streams = 100 # Example for HTTP/2
|
||||
# config.max_app_buffer_size = MAX_CONTENT_LENGTH # Another potential setting
|
||||
# It's safer to configure these via a reverse proxy (like Nginx) in production.
|
||||
# Hypercorn's defaults are usually reasonable. Let's comment these out for now.
|
||||
|
||||
# Assign startup and shutdown handlers
|
||||
config.startup_hooks = [startup_event]
|
||||
config.shutdown_hooks = [shutdown_event]
|
||||
|
||||
log_structured('info', f"Starting Hypercorn server on {SERVER_HOST}:{SERVER_PORT}")
|
||||
log_structured('info', f"Reload mode: {'Enabled' if USE_RELOADER else 'Disabled'}")
|
||||
|
||||
# Execute startup task before running the server
|
||||
log_structured('info', "Manually executing startup sequence before server start")
|
||||
startup_success = asyncio.run(startup_event())
|
||||
|
||||
# Double-check that the agent is initialized
|
||||
if not is_agent_available():
|
||||
log_structured('critical', "After startup, global_workflow_agent is still unavailable. Forcing re-initialization...")
|
||||
# Try once more to initialize
|
||||
index_success = asyncio.run(initialize_global_index())
|
||||
if not index_success or not is_agent_available():
|
||||
log_structured('critical', "Emergency initialization also failed. Server will run but chat functionality will be impaired.")
|
||||
else:
|
||||
log_structured('info', "Emergency initialization succeeded.")
|
||||
|
||||
# Run the server
|
||||
try:
|
||||
asyncio.run(hypercorn_serve(app, config))
|
||||
except KeyboardInterrupt:
|
||||
log_structured('info', "Server stopped manually (KeyboardInterrupt).")
|
||||
except Exception as run_err:
|
||||
log_structured('critical', "Hypercorn server failed to run.", {'error': str(run_err)})
|
||||
sys.exit(1)
|
||||
458
mongodb_utils.py
Normal file
458
mongodb_utils.py
Normal file
|
|
@ -0,0 +1,458 @@
|
|||
"""
|
||||
MongoDB Utilities for HP Chatbot
|
||||
|
||||
This module provides utility functions for interacting with MongoDB in the HP chatbot application.
|
||||
It includes functions for connecting to MongoDB, and managing users, conversations, and messages.
|
||||
"""
|
||||
|
||||
import pymongo
|
||||
import logging
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
from typing import Dict, List, Optional, Any, Union
|
||||
from bson.objectid import ObjectId
|
||||
import json
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler('mongodb.log')
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# MongoDB connection information
|
||||
MONGO_URI = "mongodb://hp:hp@localhost:27017/?authSource=hp_chatbot"
|
||||
DB_NAME = "hp_chatbot"
|
||||
|
||||
# Collection names
|
||||
USERS_COLLECTION = "users"
|
||||
CONVERSATIONS_COLLECTION = "conversations"
|
||||
MESSAGES_COLLECTION = "messages"
|
||||
|
||||
# Global MongoDB client
|
||||
mongo_client = None
|
||||
db = None
|
||||
|
||||
def get_db():
|
||||
"""Get or initialize the MongoDB database connection."""
|
||||
global mongo_client, db
|
||||
|
||||
if mongo_client is None:
|
||||
try:
|
||||
mongo_client = pymongo.MongoClient(MONGO_URI)
|
||||
mongo_client.admin.command('ping') # Test connection
|
||||
db = mongo_client[DB_NAME]
|
||||
logger.info("Successfully connected to MongoDB")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to connect to MongoDB: {e}")
|
||||
raise
|
||||
|
||||
return db
|
||||
|
||||
def close_connection():
|
||||
"""Close the MongoDB connection."""
|
||||
global mongo_client
|
||||
|
||||
if mongo_client:
|
||||
mongo_client.close()
|
||||
mongo_client = None
|
||||
logger.info("MongoDB connection closed")
|
||||
|
||||
# User functions
|
||||
def get_user_by_username(username: str) -> Optional[Dict]:
|
||||
"""Get a user by username."""
|
||||
try:
|
||||
db = get_db()
|
||||
user = db[USERS_COLLECTION].find_one({"username": username})
|
||||
return user
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user by username: {e}")
|
||||
return None
|
||||
|
||||
def create_or_update_user(username: str, email: Optional[str] = None) -> Optional[str]:
|
||||
"""Create a new user or update an existing one."""
|
||||
try:
|
||||
db = get_db()
|
||||
|
||||
# Check if user exists
|
||||
existing_user = db[USERS_COLLECTION].find_one({"username": username})
|
||||
|
||||
if existing_user:
|
||||
# Update last login
|
||||
db[USERS_COLLECTION].update_one(
|
||||
{"username": username},
|
||||
{"$set": {"last_login": datetime.utcnow()}}
|
||||
)
|
||||
return str(existing_user["_id"])
|
||||
else:
|
||||
# Create new user
|
||||
new_user = {
|
||||
"username": username,
|
||||
"created_at": datetime.utcnow(),
|
||||
"last_login": datetime.utcnow()
|
||||
}
|
||||
|
||||
# Only include email if it's not None to avoid unique constraint issues
|
||||
if email:
|
||||
new_user["email"] = email
|
||||
|
||||
result = db[USERS_COLLECTION].insert_one(new_user)
|
||||
return str(result.inserted_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating or updating user: {e}")
|
||||
# If the error is a duplicate key error, try to find the existing user
|
||||
if "duplicate key error" in str(e) and "username" in str(e):
|
||||
try:
|
||||
existing_user = db[USERS_COLLECTION].find_one({"username": username})
|
||||
if existing_user:
|
||||
return str(existing_user["_id"])
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
# Conversation functions
|
||||
def get_conversation(session_id: str) -> Optional[Dict]:
|
||||
"""Get a conversation by session ID."""
|
||||
try:
|
||||
db = get_db()
|
||||
conversation = db[CONVERSATIONS_COLLECTION].find_one({"session_id": session_id})
|
||||
return conversation
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting conversation: {e}")
|
||||
return None
|
||||
|
||||
def get_conversation_by_id(conversation_id: str) -> Optional[Dict]:
|
||||
"""Get a conversation by its MongoDB ID."""
|
||||
try:
|
||||
db = get_db()
|
||||
# Convert string ID to ObjectId
|
||||
try:
|
||||
obj_id = ObjectId(conversation_id)
|
||||
conversation = db[CONVERSATIONS_COLLECTION].find_one({"_id": obj_id})
|
||||
return conversation
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting conversation ID to ObjectId: {e}")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting conversation by ID: {e}")
|
||||
return None
|
||||
|
||||
def get_user_conversations(user_id: str) -> List[Dict]:
|
||||
"""Get all conversations for a user that are not marked as deleted."""
|
||||
try:
|
||||
db = get_db()
|
||||
conversations = list(db[CONVERSATIONS_COLLECTION].find(
|
||||
{
|
||||
"user_id": user_id,
|
||||
# Only return conversations that either don't have is_deleted or have it set to False
|
||||
"$or": [
|
||||
{"is_deleted": {"$exists": False}},
|
||||
{"is_deleted": False}
|
||||
]
|
||||
}
|
||||
).sort("last_updated", pymongo.DESCENDING))
|
||||
return conversations
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting user conversations: {e}")
|
||||
return []
|
||||
|
||||
def create_conversation(session_id: str, user_id: str, title: str = "New conversation") -> Optional[str]:
|
||||
"""Create a new conversation."""
|
||||
try:
|
||||
db = get_db()
|
||||
|
||||
# Check if conversation already exists with this session_id
|
||||
existing = db[CONVERSATIONS_COLLECTION].find_one({"session_id": session_id})
|
||||
if existing:
|
||||
return str(existing["_id"])
|
||||
|
||||
# Create new conversation
|
||||
new_conversation = {
|
||||
"session_id": session_id,
|
||||
"user_id": user_id,
|
||||
"title": title,
|
||||
"created_at": datetime.utcnow(),
|
||||
"last_updated": datetime.utcnow()
|
||||
}
|
||||
result = db[CONVERSATIONS_COLLECTION].insert_one(new_conversation)
|
||||
return str(result.inserted_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating conversation: {e}")
|
||||
return None
|
||||
|
||||
def update_conversation_title(conversation_id: str, title: str) -> bool:
|
||||
"""Update the title of a conversation."""
|
||||
try:
|
||||
db = get_db()
|
||||
db[CONVERSATIONS_COLLECTION].update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{"$set": {"title": title, "last_updated": datetime.utcnow()}}
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating conversation title: {e}")
|
||||
return False
|
||||
|
||||
def update_conversation_timestamp(conversation_id: str) -> bool:
|
||||
"""Update the last_updated timestamp of a conversation."""
|
||||
try:
|
||||
db = get_db()
|
||||
db[CONVERSATIONS_COLLECTION].update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{"$set": {"last_updated": datetime.utcnow()}}
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating conversation timestamp: {e}")
|
||||
return False
|
||||
|
||||
# Message functions
|
||||
def add_message(conversation_id: str, role: str, content: str,
|
||||
sources: Optional[List] = None, reasoning: Optional[List] = None,
|
||||
images: Optional[List] = None) -> Optional[str]:
|
||||
"""Add a message to a conversation."""
|
||||
try:
|
||||
db = get_db()
|
||||
|
||||
# Prepare the message document
|
||||
message = {
|
||||
"conversation_id": conversation_id,
|
||||
"role": role,
|
||||
"content": content,
|
||||
"timestamp": datetime.utcnow()
|
||||
}
|
||||
|
||||
# Add optional fields with serialization
|
||||
if sources:
|
||||
# Serialize sources
|
||||
serialized_sources = json.loads(json.dumps(sources, default=serialize_custom_objects))
|
||||
message["sources"] = serialized_sources
|
||||
|
||||
if reasoning:
|
||||
# Serialize reasoning steps
|
||||
serialized_reasoning = json.loads(json.dumps(reasoning, default=serialize_custom_objects))
|
||||
message["reasoning"] = serialized_reasoning
|
||||
|
||||
if images:
|
||||
# Serialize images
|
||||
serialized_images = json.loads(json.dumps(images, default=serialize_custom_objects))
|
||||
message["images"] = serialized_images
|
||||
|
||||
# Insert the message
|
||||
result = db[MESSAGES_COLLECTION].insert_one(message)
|
||||
|
||||
# Update the conversation timestamp
|
||||
update_conversation_timestamp(conversation_id)
|
||||
|
||||
return str(result.inserted_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Error adding message: {e}")
|
||||
return None
|
||||
|
||||
def serialize_custom_objects(obj):
|
||||
"""
|
||||
Custom serialization function for MongoDB.
|
||||
Handles special types like ActionReasoningStep and other custom classes.
|
||||
"""
|
||||
if hasattr(obj, '__dict__'):
|
||||
# For ActionReasoningStep, ObservationReasoningStep, etc.
|
||||
if obj.__class__.__name__.endswith('ReasoningStep'):
|
||||
result = {
|
||||
'type': obj.__class__.__name__
|
||||
}
|
||||
|
||||
# Add attributes based on the specific type
|
||||
if hasattr(obj, 'action'):
|
||||
result['action'] = obj.action
|
||||
if hasattr(obj, 'action_input'):
|
||||
result['action_input'] = obj.action_input
|
||||
if hasattr(obj, 'observation'):
|
||||
result['observation'] = obj.observation
|
||||
if hasattr(obj, 'response'):
|
||||
result['response'] = obj.response
|
||||
if hasattr(obj, 'thought'):
|
||||
result['thought'] = obj.thought
|
||||
|
||||
return result
|
||||
|
||||
# For other objects with __dict__
|
||||
return {k: v for k, v in obj.__dict__.items()
|
||||
if not k.startswith('_') and not callable(v)}
|
||||
|
||||
# For objects with content property
|
||||
if hasattr(obj, 'content'):
|
||||
return str(obj.content)
|
||||
|
||||
# For objects with string representation
|
||||
try:
|
||||
return str(obj)
|
||||
except:
|
||||
return f"<Unserializable object of type {type(obj).__name__}>"
|
||||
|
||||
def get_conversation_messages(conversation_id: str) -> List[Dict]:
|
||||
"""Get all messages in a conversation."""
|
||||
try:
|
||||
db = get_db()
|
||||
messages = list(db[MESSAGES_COLLECTION].find(
|
||||
{"conversation_id": conversation_id}
|
||||
).sort("timestamp", pymongo.ASCENDING))
|
||||
return messages
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting conversation messages: {e}")
|
||||
return []
|
||||
|
||||
def generate_conversation_title(conversation_id: str, content: List[Dict]) -> Optional[str]:
|
||||
"""
|
||||
Generate a title for a conversation based on its content using AI.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation
|
||||
content: List of messages in the conversation
|
||||
|
||||
Returns:
|
||||
A generated title, or None if generation failed
|
||||
"""
|
||||
try:
|
||||
from llama_index.llms.openai import OpenAI as LlamaOpenAI
|
||||
|
||||
# Extract text from the conversation (first few messages)
|
||||
conversation_text = "\n".join([
|
||||
f"{msg['role']}: {msg['content']}"
|
||||
for msg in content[:5] # Use first 5 messages or fewer
|
||||
])
|
||||
|
||||
# Create LLM instance
|
||||
llm = LlamaOpenAI(
|
||||
model="chatgpt-4o-latest",
|
||||
temperature=0.3,
|
||||
)
|
||||
|
||||
# Generate title
|
||||
prompt = f"""
|
||||
Based on the following conversation, generate a short, descriptive title (max 5 words):
|
||||
|
||||
{conversation_text}
|
||||
|
||||
Title:
|
||||
"""
|
||||
|
||||
response = llm.complete(prompt)
|
||||
title = response.text.strip()
|
||||
|
||||
# Update the conversation with the new title
|
||||
update_conversation_title(conversation_id, title)
|
||||
|
||||
return title
|
||||
except Exception as e:
|
||||
logger.error(f"Error generating conversation title: {e}")
|
||||
return "New conversation" # Fallback title
|
||||
|
||||
def delete_conversation(conversation_id: str, hard_delete: bool = False) -> bool:
|
||||
"""
|
||||
Delete a conversation and its messages.
|
||||
|
||||
Args:
|
||||
conversation_id: The ID of the conversation to delete
|
||||
hard_delete: If True, physically delete the records; if False, mark as deleted
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
db = get_db()
|
||||
|
||||
if hard_delete:
|
||||
# Permanently delete all messages in the conversation
|
||||
db[MESSAGES_COLLECTION].delete_many({"conversation_id": conversation_id})
|
||||
|
||||
# Permanently delete the conversation
|
||||
db[CONVERSATIONS_COLLECTION].delete_one({"_id": ObjectId(conversation_id)})
|
||||
else:
|
||||
# Mark the conversation as deleted
|
||||
db[CONVERSATIONS_COLLECTION].update_one(
|
||||
{"_id": ObjectId(conversation_id)},
|
||||
{"$set": {"is_deleted": True}}
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Error deleting conversation: {e}")
|
||||
return False
|
||||
|
||||
# Session state management
|
||||
def get_session_state(session_id: str) -> Optional[Dict]:
|
||||
"""
|
||||
Get the session state from MongoDB.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
|
||||
Returns:
|
||||
The session state or None if not found
|
||||
"""
|
||||
try:
|
||||
conversation = get_conversation(session_id)
|
||||
if conversation:
|
||||
# Return a minimal session state
|
||||
return {
|
||||
"initialized": True,
|
||||
"conversation_id": str(conversation["_id"]),
|
||||
"user_id": conversation["user_id"]
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting session state: {e}")
|
||||
return None
|
||||
|
||||
def create_session_state(session_id: str, user_id: str, conversation_id: Optional[str] = None) -> Optional[Dict]:
|
||||
"""
|
||||
Create a new session state in MongoDB.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
user_id: The user ID
|
||||
conversation_id: Optional conversation ID. If not provided, a new conversation will be created.
|
||||
|
||||
Returns:
|
||||
The created session state or None if creation failed
|
||||
"""
|
||||
try:
|
||||
if not conversation_id:
|
||||
conversation_id = create_conversation(session_id, user_id)
|
||||
|
||||
if conversation_id:
|
||||
return {
|
||||
"initialized": True,
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": user_id
|
||||
}
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating session state: {e}")
|
||||
return None
|
||||
|
||||
def update_session_state(session_id: str, state: Dict) -> bool:
|
||||
"""
|
||||
Update the session state in MongoDB.
|
||||
|
||||
Args:
|
||||
session_id: The session ID
|
||||
state: The new state to save
|
||||
|
||||
Returns:
|
||||
True if the update was successful, False otherwise
|
||||
"""
|
||||
try:
|
||||
conversation = get_conversation(session_id)
|
||||
if conversation:
|
||||
# If we need to store additional session state beyond the conversation
|
||||
# we could add a separate collection for that
|
||||
return True
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating session state: {e}")
|
||||
return False
|
||||
106
requirements.txt
Normal file
106
requirements.txt
Normal file
|
|
@ -0,0 +1,106 @@
|
|||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.11.16
|
||||
aiosignal==1.3.2
|
||||
annotated-types==0.7.0
|
||||
anyio==4.9.0
|
||||
asgiref==3.8.1
|
||||
asyncio==3.4.3
|
||||
attrs==25.3.0
|
||||
banks==2.1.1
|
||||
beautifulsoup4==4.13.3
|
||||
blinker==1.9.0
|
||||
certifi==2025.1.31
|
||||
charset-normalizer==3.4.1
|
||||
click==8.1.8
|
||||
colorama==0.4.6
|
||||
dataclasses-json==0.6.7
|
||||
Deprecated==1.2.18
|
||||
dirtyjson==1.0.8
|
||||
distro==1.9.0
|
||||
dnspython==2.7.0
|
||||
filetype==1.2.0
|
||||
Flask==3.1.0
|
||||
flask-cors==5.0.1
|
||||
frozenlist==1.5.0
|
||||
fsspec==2025.3.2
|
||||
future==1.0.0
|
||||
greenlet==3.1.1
|
||||
griffe==1.7.2
|
||||
h11==0.14.0
|
||||
h2==4.2.0
|
||||
hpack==4.1.0
|
||||
httpcore==1.0.7
|
||||
httpx==0.28.1
|
||||
Hypercorn==0.17.3
|
||||
hyperframe==6.1.0
|
||||
idna==3.10
|
||||
itsdangerous==2.2.0
|
||||
Jinja2==3.1.6
|
||||
jiter==0.9.0
|
||||
joblib==1.4.2
|
||||
llama-cloud==0.1.17
|
||||
llama-cloud-services==0.6.9
|
||||
llama-index==0.12.33
|
||||
llama-index-agent-openai==0.4.6
|
||||
llama-index-cli==0.4.1
|
||||
llama-index-core==0.12.33.post1
|
||||
llama-index-embeddings-openai==0.3.1
|
||||
llama-index-graph-stores-neo4j==0.4.6
|
||||
llama-index-indices-managed-llama-cloud==0.6.11
|
||||
llama-index-llms-openai==0.3.30
|
||||
llama-index-multi-modal-llms-openai==0.4.3
|
||||
llama-index-program-openai==0.3.1
|
||||
llama-index-question-gen-openai==0.3.0
|
||||
llama-index-readers-file==0.4.7
|
||||
llama-index-readers-llama-parse==0.4.0
|
||||
llama-parse==0.6.4.post1
|
||||
lxml==5.3.2
|
||||
markdown2==2.5.3
|
||||
MarkupSafe==3.0.2
|
||||
marshmallow==3.26.1
|
||||
multidict==6.3.2
|
||||
mypy-extensions==1.0.0
|
||||
neo4j==5.28.1
|
||||
nest-asyncio==1.6.0
|
||||
networkx==3.4.2
|
||||
nltk==3.9.1
|
||||
numpy==2.2.4
|
||||
openai==1.71.0
|
||||
packaging==24.2
|
||||
pandas==2.2.3
|
||||
pillow==11.1.0
|
||||
platformdirs==4.3.7
|
||||
priority==2.0.0
|
||||
propcache==0.3.1
|
||||
pydantic==2.11.2
|
||||
pydantic_core==2.33.1
|
||||
pymongo==4.7.0
|
||||
pypdf==5.4.0
|
||||
python-dateutil==2.9.0.post0
|
||||
python-docx==1.1.2
|
||||
python-dotenv==1.1.0
|
||||
python-louvain==0.16
|
||||
pytz==2025.2
|
||||
PyYAML==6.0.2
|
||||
regex==2024.11.6
|
||||
requests==2.32.3
|
||||
setuptools==79.0.0
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
soupsieve==2.6
|
||||
SQLAlchemy==2.0.40
|
||||
striprtf==0.0.26
|
||||
tenacity==9.1.2
|
||||
tiktoken==0.9.0
|
||||
tqdm==4.67.1
|
||||
typing-inspect==0.9.0
|
||||
typing-inspection==0.4.0
|
||||
typing_extensions==4.13.1
|
||||
tzdata==2025.2
|
||||
urllib3==2.3.0
|
||||
uuid==1.30
|
||||
Werkzeug==3.1.3
|
||||
wheel==0.45.1
|
||||
wrapt==1.17.2
|
||||
wsproto==1.2.0
|
||||
yarl==1.19.0
|
||||
162
session_manager.py
Normal file
162
session_manager.py
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
# hp_chatbot/session_manager.py
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
import uuid
|
||||
import time
|
||||
from datetime import datetime
|
||||
|
||||
# Import MongoDB utilities from the separate file
|
||||
from mongodb_utils import (
|
||||
get_db, create_or_update_user, get_user_by_username,
|
||||
create_conversation, get_conversation, get_conversation_by_id, get_user_conversations,
|
||||
get_conversation_messages, add_message, update_conversation_title,
|
||||
generate_conversation_title, delete_conversation,
|
||||
get_session_state as db_get_session_state, # Rename to avoid conflict
|
||||
create_session_state as db_create_session_state,
|
||||
update_session_state as db_update_session_state
|
||||
)
|
||||
|
||||
# Import necessary components from ai_core
|
||||
# Use forward reference for ReActAgent2 if needed, or import normally if load order allows
|
||||
from ai_core import global_workflow_agent, ReActAgent2
|
||||
|
||||
# Import logging
|
||||
from utils import log_structured
|
||||
|
||||
# --- In-memory Session Cache ---
|
||||
# Stores session-specific data like associated conversation_id and user_id
|
||||
# Key: session_id (string), Value: dict {'conversation_id': ObjectId, 'user_id': ObjectId}
|
||||
# Note: The actual agent *instance* is now global (global_workflow_agent),
|
||||
# but its *memory* provides conversation context. Resetting the agent's memory
|
||||
# effectively resets the conversation for that agent instance.
|
||||
# We still need to map a *frontend session ID* to a *persistent conversation ID* in the DB.
|
||||
chat_state: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
def get_or_create_session_state(session_id: str, username: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Gets or creates session state, mapping session_id to user and conversation in MongoDB.
|
||||
Returns a dictionary containing 'conversation_id' and 'user_id'.
|
||||
The 'workflow_agent' is now global and not stored per session here.
|
||||
|
||||
Args:
|
||||
session_id: The unique identifier from the frontend/client.
|
||||
username: Optional username for linking to a user.
|
||||
|
||||
Returns:
|
||||
A dictionary like {'conversation_id': ObjectId, 'user_id': ObjectId, 'initialized': bool}
|
||||
"""
|
||||
global global_workflow_agent # Access the global agent
|
||||
|
||||
# 1. Check in-memory cache first
|
||||
if session_id in chat_state:
|
||||
cached_state = chat_state[session_id]
|
||||
# Ensure it has the necessary keys before returning
|
||||
if 'conversation_id' in cached_state and 'user_id' in cached_state:
|
||||
log_structured('debug', f'Session cache hit for {session_id}', {'cached_state': cached_state})
|
||||
cached_state['initialized'] = global_workflow_agent is not None
|
||||
return cached_state
|
||||
else:
|
||||
log_structured('warning', f'Cached state for {session_id} is incomplete. Re-fetching.', {'cached_state': cached_state})
|
||||
# Remove incomplete entry and proceed to DB check
|
||||
del chat_state[session_id]
|
||||
|
||||
|
||||
# 2. Check persistent storage (MongoDB) for this session_id
|
||||
mongo_session_data = None
|
||||
try:
|
||||
mongo_session_data = db_get_session_state(session_id)
|
||||
except Exception as db_err:
|
||||
log_structured('error', f'Error accessing MongoDB for session {session_id}: {str(db_err)}')
|
||||
# Continue with fallback approach
|
||||
|
||||
user_id = None
|
||||
conversation_id = None
|
||||
|
||||
if mongo_session_data:
|
||||
log_structured('info', f'Loaded existing session state from MongoDB for {session_id}', {
|
||||
'db_data': mongo_session_data # Be careful logging sensitive data
|
||||
})
|
||||
try:
|
||||
user_id = mongo_session_data.get('user_id')
|
||||
conversation_id = mongo_session_data.get('conversation_id')
|
||||
except Exception as parse_err:
|
||||
log_structured('error', f'Error parsing MongoDB session data: {str(parse_err)}')
|
||||
mongo_session_data = None
|
||||
|
||||
# Validate retrieved IDs
|
||||
if not user_id or not conversation_id:
|
||||
log_structured('error', f'Incomplete session data found in DB for {session_id}. Recreating.', {'db_data': mongo_session_data})
|
||||
# Force creation of a new conversation/session state below
|
||||
mongo_session_data = None # Treat as if not found
|
||||
|
||||
# 3. If not found in cache or DB, or if data was invalid, create new state
|
||||
if not mongo_session_data:
|
||||
log_structured('info', f'No valid session state found for {session_id}. Creating new.', {'username': username})
|
||||
|
||||
# Determine User ID
|
||||
effective_username = username if username else f"anonymous_{session_id[:8]}"
|
||||
|
||||
try:
|
||||
user_id = create_or_update_user(effective_username)
|
||||
except Exception as user_err:
|
||||
log_structured('error', f'Failed to create user in MongoDB: {str(user_err)}')
|
||||
# Create a temporary ID for in-memory operation
|
||||
user_id = f"temp_user_{uuid.uuid4().hex}"
|
||||
log_structured('info', f'Using temporary user ID: {user_id}')
|
||||
|
||||
if not user_id:
|
||||
log_structured('error', f"Failed to create or update user: {effective_username}")
|
||||
# Create a fallback user ID for in-memory operation
|
||||
user_id = f"fallback_user_{uuid.uuid4().hex}"
|
||||
log_structured('info', f'Using fallback user ID: {user_id}')
|
||||
|
||||
# Create a new Conversation linked to this user
|
||||
# Use a default title, it will be updated after the first interaction
|
||||
new_conv_title = f"New Chat ({datetime.now().strftime('%Y-%m-%d %H:%M')})"
|
||||
|
||||
try:
|
||||
conversation_id = create_conversation(session_id, user_id, title=new_conv_title)
|
||||
except Exception as conv_err:
|
||||
log_structured('error', f'Failed to create conversation in MongoDB: {str(conv_err)}')
|
||||
# Create a temporary conversation ID for in-memory operation
|
||||
conversation_id = f"temp_conv_{uuid.uuid4().hex}"
|
||||
log_structured('info', f'Using temporary conversation ID: {conversation_id}')
|
||||
|
||||
if not conversation_id:
|
||||
log_structured('error', f"Failed to create conversation for session {session_id}, user {user_id}")
|
||||
# Create a fallback conversation ID for in-memory operation
|
||||
conversation_id = f"fallback_conv_{uuid.uuid4().hex}"
|
||||
log_structured('info', f'Using fallback conversation ID: {conversation_id}')
|
||||
|
||||
# Store the new session state linkage in MongoDB
|
||||
try:
|
||||
db_create_session_state(session_id, user_id, conversation_id)
|
||||
except Exception as session_err:
|
||||
log_structured('warning', f"Failed to persist new session state link in DB: {str(session_err)}")
|
||||
# Continue with in-memory operation
|
||||
|
||||
log_structured('info', f'Created new conversation and session state link', {
|
||||
'session_id': session_id, 'user_id': user_id, 'conversation_id': conversation_id
|
||||
})
|
||||
|
||||
# 4. Store in the in-memory cache and return
|
||||
# We store the DB IDs, not the agent object itself
|
||||
current_state = {
|
||||
'initialized': global_workflow_agent is not None,
|
||||
'conversation_id': conversation_id,
|
||||
'user_id': user_id
|
||||
}
|
||||
chat_state[session_id] = current_state
|
||||
|
||||
return current_state
|
||||
|
||||
def clear_chat_state_cache(session_id: Optional[str] = None):
|
||||
""" Clears the in-memory chat state cache. """
|
||||
global chat_state
|
||||
if session_id:
|
||||
if session_id in chat_state:
|
||||
del chat_state[session_id]
|
||||
log_structured('info', f'Cleared in-memory cache for session {session_id}')
|
||||
else:
|
||||
chat_state = {}
|
||||
log_structured('info', 'Cleared all in-memory session cache')
|
||||
102
shared_state.py
Normal file
102
shared_state.py
Normal file
|
|
@ -0,0 +1,102 @@
|
|||
# hp_chatbot/shared_state.py
|
||||
"""
|
||||
Shared state module to store global variables that need to be
|
||||
accessible across different modules and ensures proper synchronization.
|
||||
"""
|
||||
|
||||
# Store the AI agent here so it's properly shared across modules
|
||||
global_workflow_agent = None
|
||||
|
||||
# Store the global index here
|
||||
global_index = None
|
||||
|
||||
# Store GraphRAG components
|
||||
global_graph_store = None
|
||||
global_property_graph_index = None
|
||||
global_graphrag_query_engine = None
|
||||
|
||||
# Helper to set the global agent
|
||||
def set_global_agent(agent):
|
||||
"""Set the global agent instance."""
|
||||
global global_workflow_agent
|
||||
from utils import log_structured
|
||||
|
||||
if agent is None:
|
||||
log_structured('error', 'Attempted to set global_workflow_agent to None')
|
||||
return False
|
||||
|
||||
try:
|
||||
# Check that the agent has a run method
|
||||
if not hasattr(agent, 'run'):
|
||||
log_structured('error', 'Agent being set does not have a run method')
|
||||
return False
|
||||
|
||||
# Set the global agent
|
||||
global_workflow_agent = agent
|
||||
|
||||
# Verify it was set correctly
|
||||
has_run = hasattr(global_workflow_agent, 'run')
|
||||
success = global_workflow_agent is not None and has_run
|
||||
|
||||
log_structured('info', f'Global agent set successfully: {success}', {
|
||||
'has_run_method': has_run,
|
||||
'agent_type': type(agent).__name__
|
||||
})
|
||||
|
||||
return success
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error setting global agent: {str(e)}')
|
||||
return False
|
||||
|
||||
# Helper to set the global index
|
||||
def set_global_index(index):
|
||||
"""Set the global index instance."""
|
||||
global global_index
|
||||
global_index = index
|
||||
return global_index is not None
|
||||
|
||||
# Helper to set the GraphRAG components
|
||||
def set_graphrag_components(graph_store, property_graph_index, graphrag_query_engine):
|
||||
"""Set the global GraphRAG components."""
|
||||
global global_graph_store, global_property_graph_index, global_graphrag_query_engine
|
||||
|
||||
from utils import log_structured
|
||||
|
||||
global_graph_store = graph_store
|
||||
global_property_graph_index = property_graph_index
|
||||
global_graphrag_query_engine = graphrag_query_engine
|
||||
|
||||
components_set = (global_graph_store is not None and
|
||||
global_property_graph_index is not None and
|
||||
global_graphrag_query_engine is not None)
|
||||
|
||||
log_structured('info', f'GraphRAG components set successfully: {components_set}')
|
||||
return components_set
|
||||
|
||||
# Helper to get agent status
|
||||
def is_agent_available():
|
||||
"""
|
||||
Check if the global agent is available.
|
||||
Uses direct reference to ensure we check the current module state.
|
||||
"""
|
||||
from utils import log_structured
|
||||
|
||||
# Access the module-level global_workflow_agent directly
|
||||
# We are using the global_workflow_agent from this module, not importing it
|
||||
# This avoids circular import issues and ensures we're checking the actual current value
|
||||
|
||||
# IMPORTANT: Declare as global to ensure we're checking the correct module-level variable
|
||||
global global_workflow_agent
|
||||
|
||||
is_available = global_workflow_agent is not None and hasattr(global_workflow_agent, 'run')
|
||||
|
||||
# Add detailed logging
|
||||
if not is_available:
|
||||
if global_workflow_agent is None:
|
||||
log_structured('warning', 'Agent availability check failed: global_workflow_agent is None')
|
||||
elif not hasattr(global_workflow_agent, 'run'):
|
||||
log_structured('warning', 'Agent availability check failed: global_workflow_agent has no run method')
|
||||
else:
|
||||
log_structured('debug', 'Agent availability check passed: agent exists and has run method')
|
||||
|
||||
return is_available
|
||||
178
utils.py
Normal file
178
utils.py
Normal file
|
|
@ -0,0 +1,178 @@
|
|||
# hp_chatbot/utils.py
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Dict, Any
|
||||
from config import LOG_FILE_PATH, CHUNK_FOLDER, UPLOAD_METADATA_FOLDER, ALLOWED_EXTENSIONS
|
||||
from llama_index.core.tools import ToolOutput # Import ToolOutput for serialization check
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, # Set level from config later if needed
|
||||
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||||
handlers=[
|
||||
logging.StreamHandler(),
|
||||
logging.FileHandler(LOG_FILE_PATH)
|
||||
]
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import after logger is defined to avoid circular dependency if CustomJSONEncoder uses logger
|
||||
from json_utils import CustomJSONEncoder
|
||||
|
||||
# --- Logging Helper ---
|
||||
def log_structured(level: str, event_message: str, data: Optional[Dict[str, Any]] = None):
|
||||
"""
|
||||
Structured logging helper
|
||||
Args:
|
||||
level: The logging level ('info', 'error', etc.)
|
||||
event_message: The main message describing the event
|
||||
data: Optional dictionary of additional data to log
|
||||
"""
|
||||
# Basic serializer to handle common non-serializable types safely
|
||||
def safe_serialize(obj):
|
||||
if isinstance(obj, ToolOutput):
|
||||
# Let CustomJSONEncoder handle ToolOutput specifics if passed directly
|
||||
return obj # Pass it through for the encoder
|
||||
if isinstance(obj, (datetime)):
|
||||
return obj.isoformat()
|
||||
if isinstance(obj, bytes):
|
||||
return "<bytes>"
|
||||
if hasattr(obj, '__dict__') and not isinstance(obj, (ToolOutput)): # Avoid double processing ToolOutput
|
||||
# Very basic dict representation, avoid complex objects
|
||||
try:
|
||||
# Limit recursion depth or complexity if needed
|
||||
return {k: safe_serialize(v) for k, v in obj.__dict__.items() if not k.startswith('_') and not callable(v)}
|
||||
except Exception:
|
||||
return f"<Object type {type(obj).__name__}>" # Fallback for complex objects
|
||||
# Add more types as needed (e.g., ObjectId, specific LlamaIndex objects if problematic)
|
||||
return obj # Let JSON encoder handle the rest or fail
|
||||
|
||||
log_data = {
|
||||
'timestamp': datetime.now(timezone.utc).isoformat(),
|
||||
'event': event_message
|
||||
}
|
||||
|
||||
if data is not None:
|
||||
try:
|
||||
# Apply safe serialization recursively to the data dictionary
|
||||
serialized_data = json.loads(json.dumps(data, default=safe_serialize))
|
||||
log_data['data'] = serialized_data
|
||||
except (TypeError, OverflowError, ValueError) as json_err:
|
||||
logger.error(f"Serialization error in log_structured for event '{event_message}': {json_err}. Logging basic info.")
|
||||
log_data['data_serialization_error'] = str(json_err)
|
||||
# Attempt to log basic data structure if possible
|
||||
try:
|
||||
basic_data = {k: str(type(v)) for k, v in data.items()}
|
||||
log_data['data_structure'] = basic_data
|
||||
except Exception:
|
||||
log_data['data_structure'] = "<Could not represent data structure>"
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error during safe_serialize or json.dumps in log_structured: {e}")
|
||||
log_data['logging_error'] = str(e)
|
||||
|
||||
|
||||
try:
|
||||
# Use the custom encoder for final JSON dump
|
||||
log_string = json.dumps(log_data, cls=CustomJSONEncoder)
|
||||
getattr(logger, level.lower())(log_string)
|
||||
except AttributeError:
|
||||
logger.error(f"Invalid log level: {level}. Defaulting to error.")
|
||||
logger.error(json.dumps(log_data, cls=CustomJSONEncoder))
|
||||
except Exception as e:
|
||||
# Fallback logging if JSON fails completely
|
||||
logger.error(f"FATAL: Error serializing log message with CustomJSONEncoder: {e}")
|
||||
fallback_msg = f"{event_message}"
|
||||
if data:
|
||||
fallback_msg += f" | Data keys: {list(data.keys())}"
|
||||
logger.error(fallback_msg)
|
||||
|
||||
|
||||
# --- Response Validation ---
|
||||
def validate_response(response: dict) -> bool:
|
||||
"""Validate the structure of a response"""
|
||||
# Simple validation, adjust as needed
|
||||
required_fields = {'response', 'sources', 'reasoning'}
|
||||
return isinstance(response, dict) and all(field in response for field in required_fields)
|
||||
|
||||
# --- File Handling Utilities (Keep if chunking/upload is potentially needed later) ---
|
||||
def get_upload_metadata(upload_id):
|
||||
"""Load metadata for an upload"""
|
||||
metadata_path = os.path.join(UPLOAD_METADATA_FOLDER, f"{upload_id}.json")
|
||||
if not os.path.exists(metadata_path):
|
||||
return None
|
||||
try:
|
||||
with open(metadata_path, 'r') as f:
|
||||
return json.load(f)
|
||||
except Exception as e:
|
||||
log_structured('error', f'Failed to load upload metadata for {upload_id}', {'error': str(e)})
|
||||
return None
|
||||
|
||||
def save_upload_metadata(upload_id, metadata):
|
||||
"""Save metadata for an upload"""
|
||||
metadata_path = os.path.join(UPLOAD_METADATA_FOLDER, f"{upload_id}.json")
|
||||
try:
|
||||
with open(metadata_path, 'w') as f:
|
||||
json.dump(metadata, f)
|
||||
except Exception as e:
|
||||
log_structured('error', f'Failed to save upload metadata for {upload_id}', {'error': str(e)})
|
||||
|
||||
def get_chunk_path(upload_id, chunk_index):
|
||||
"""Get path for a specific chunk"""
|
||||
upload_chunk_dir = os.path.join(CHUNK_FOLDER, upload_id)
|
||||
os.makedirs(upload_chunk_dir, exist_ok=True)
|
||||
return os.path.join(upload_chunk_dir, f"chunk_{chunk_index}")
|
||||
|
||||
def combine_chunks(upload_id, destination_path):
|
||||
"""Combine all chunks into a single file"""
|
||||
metadata = get_upload_metadata(upload_id)
|
||||
if not metadata or 'totalChunks' not in metadata:
|
||||
log_structured('error', f'Metadata missing or invalid for combining chunks: {upload_id}')
|
||||
return False
|
||||
|
||||
upload_chunk_dir = os.path.join(CHUNK_FOLDER, upload_id)
|
||||
try:
|
||||
with open(destination_path, 'wb') as outfile:
|
||||
for i in range(metadata['totalChunks']):
|
||||
chunk_path = os.path.join(upload_chunk_dir, f"chunk_{i}")
|
||||
if os.path.exists(chunk_path):
|
||||
with open(chunk_path, 'rb') as infile:
|
||||
outfile.write(infile.read())
|
||||
else:
|
||||
log_structured('error', f'Chunk {i} missing for upload {upload_id}')
|
||||
# Clean up partially created file
|
||||
if os.path.exists(destination_path):
|
||||
os.remove(destination_path)
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
log_structured('error', f'Error combining chunks for {upload_id}', {'error': str(e)})
|
||||
# Clean up partially created file
|
||||
if os.path.exists(destination_path):
|
||||
os.remove(destination_path)
|
||||
return False
|
||||
|
||||
def clear_upload_chunks(upload_id):
|
||||
"""Remove all chunks and metadata for an upload"""
|
||||
upload_chunk_dir = os.path.join(CHUNK_FOLDER, upload_id)
|
||||
if os.path.exists(upload_chunk_dir):
|
||||
try:
|
||||
shutil.rmtree(upload_chunk_dir)
|
||||
log_structured('info', f'Cleared chunk directory: {upload_chunk_dir}')
|
||||
except Exception as e:
|
||||
log_structured('error', f'Failed to remove chunk directory {upload_chunk_dir}', {'error': str(e)})
|
||||
|
||||
metadata_path = os.path.join(UPLOAD_METADATA_FOLDER, f"{upload_id}.json")
|
||||
if os.path.exists(metadata_path):
|
||||
try:
|
||||
os.remove(metadata_path)
|
||||
log_structured('info', f'Cleared metadata file: {metadata_path}')
|
||||
except Exception as e:
|
||||
log_structured('error', f'Failed to remove metadata file {metadata_path}', {'error': str(e)})
|
||||
|
||||
|
||||
def allowed_file(filename):
|
||||
"""Check if a filename has an allowed extension"""
|
||||
return '.' in filename and filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
|
||||
Loading…
Add table
Reference in a new issue