Backend:
- retriever.py: inject user display_name + region into system prompt, remove "ask which country" instructions
- chat.py: pass user_display_name and user_region_code to retriever.query()
- sharepoint_browse.py: fix downloads from MS Graph pre-signed URLs by setting follow_redirects=True; add webhook endpoint (POST /webhook) with validationToken support and clientState validation; add SharePoint sources CRUD (GET/POST /sources, DELETE /sources/{id}, POST /sources/{id}/sync)
- sharepoint.py model: add delta_link column to SharePointSource
- migration 019: ALTER TABLE sharepoint_sources ADD COLUMN delta_link TEXT
- sharepoint_sync.py: persist delta_link after each sync; add sync_all_active_sources Celery task
- celery_app.py: add hourly beat schedule for sync_all_active_sources
- knowledge.py: add PATCH /documents/{id} endpoint to update region_code/department_id/description; re-tags Qdrant vectors via update_document_payload()
- document_processor.py: add update_document_payload() method to set_payload on Qdrant vectors
- admin.py: add region_code to AdminUserUpdateRequest; handle in PATCH /users/{id}; use u.region_code directly in _user_to_response
- schemas/admin.py: add region_code field to AdminUserUpdateRequest
- llm.py: fix str + list crash in stream_with_tools, stream_completion, generate_completion — chunk.content can be a list of content blocks (OpenAI/Anthropic tool responses)
Frontend:
- knowledge-uploader.tsx: add regions/departments props with dropdowns; include region_code + department_id in upload FormData
- admin/page.tsx: fetch regions/departments (super_admin only); pass to KnowledgeUploader; add Region column to documents table with inline popover edit; add SharePoint Sources tab; guard regions/depts fetch to super_admin
- users-tab.tsx: add Region column with dropdown override per user; call PATCH /admin/users/{id} with region_code
- sharepoint-sources-tab.tsx: new component — table of monitored SP folders with Add/Sync/Delete
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
559 lines
20 KiB
Python
559 lines
20 KiB
Python
"""
|
|
Celery Tasks for SharePoint Document Synchronization
|
|
|
|
Tasks:
|
|
sync_sharepoint_source(source_id)
|
|
Full or incremental sync of all documents in a SharePointSource.
|
|
Fetches changed items via Graph Delta API, processes new/updated docs,
|
|
deletes removed docs from Qdrant and DB.
|
|
|
|
process_single_document(drive_id, item_id, source_id, ...)
|
|
Download and index a single SharePoint document. Used for both
|
|
initial sync and webhook-triggered updates.
|
|
|
|
process_webhook_notification(notification_data)
|
|
Handle an incoming webhook payload from Microsoft Graph.
|
|
Dispatches process_single_document for each changed item.
|
|
|
|
renew_expiring_webhooks()
|
|
Periodic task (daily). Renews webhook subscriptions that will
|
|
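expire within 7 days to maintain real-time sync.

sync_all_active_sources()
    Periodic task (hourly). Dispatches sync_sharepoint_source for every
    active SharePointSource.

Each task is a synchronous Celery entry point that runs its async
implementation with asyncio.run().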
|
|
"""
|
|
import asyncio
|
|
import logging
|
|
from datetime import datetime, timedelta
|
|
from typing import Any, Dict, List, Optional
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import select, update
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.document_processor import DocumentProcessor, DocumentProcessingError
|
|
from app.core.sharepoint_client import SharePointGraphClient, GraphAPIError
|
|
from app.database import AsyncSessionLocal
|
|
from app.models.sharepoint import (
|
|
JobStatus,
|
|
JobType,
|
|
SharePointDocument,
|
|
SharePointSource,
|
|
SharePointWebhook,
|
|
SyncJob,
|
|
SyncStatus,
|
|
)
|
|
from celery_app import celery_app
|
|
|
|
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# =============================================================================
|
|
# Task: Full / Incremental Sync
|
|
# =============================================================================
|
|
|
|
@celery_app.task(
    name="app.tasks.sharepoint_sync.sync_sharepoint_source",
    bind=True,
    max_retries=3,
    default_retry_delay=60,
    autoretry_for=(Exception,),
    retry_backoff=True,
)
def sync_sharepoint_source(self, source_id: str) -> Dict[str, Any]:
    """
    Celery entry point: sync one SharePointSource through the Graph delta feed.

    When the source has a stored delta link, only items changed since that
    link are fetched; otherwise the run starts without one ("initial").
    Any exception escaping the sync is retried automatically (up to 3 times,
    with backoff) because of ``autoretry_for=(Exception,)``.

    Args:
        source_id: UUID string of the SharePointSource record.

    Returns:
        Counts of added / updated / deleted / failed documents, or a small
        dict explaining why the source was not synced.
    """
    task_id = self.request.id
    sync_coro = _async_sync_sharepoint_source(source_id, task_id)
    return asyncio.run(sync_coro)
|
|
|
|
|
|
async def _async_sync_sharepoint_source(source_id: str, celery_task_id: str) -> Dict[str, Any]:
    """
    Async implementation of sync_sharepoint_source.

    Creates a SyncJob, pulls the Graph delta feed for the source's drive,
    applies each change via _process_delta_item, stores the new delta link
    and marks the job / source as completed or failed.

    Args:
        source_id: UUID string of the SharePointSource to sync.
        celery_task_id: ID of the Celery task running this sync (stored on the job).

    Returns:
        ``{"added", "updated", "deleted", "failed"}`` counts on success,
        ``{"error": "source_not_found"}`` or ``{"skipped": "inactive"}`` otherwise.

    Raises:
        Exception: errors outside per-item processing (delta fetch, final
            commit) are re-raised after the failure is recorded, so Celery
            can retry the task.
    """
    async with AsyncSessionLocal() as session:
        # Load source
        source = await _get_source(session, source_id)
        if not source:
            logger.error("SharePointSource %s not found", source_id)
            return {"error": "source_not_found"}

        if not source.is_active:
            logger.info("Source %s is inactive, skipping", source_id)
            return {"skipped": "inactive"}

        # Kept in a local so the failure path never has to read it back from
        # an ORM instance that a rollback has expired.
        started_at = datetime.utcnow()

        # NOTE(review): recorded as INCREMENTAL even on the first run, when no
        # delta link is stored yet.
        job = SyncJob(
            source_id=UUID(source_id),
            job_type=JobType.INCREMENTAL,
            status=JobStatus.RUNNING,
            celery_task_id=celery_task_id,
            started_at=started_at,
        )
        session.add(job)
        await session.flush()  # Get job.id

        source.last_sync_status = SyncStatus.SYNCING
        await session.commit()

        stats = {"added": 0, "updated": 0, "deleted": 0, "failed": 0}

        try:
            client = SharePointGraphClient()
            processor = DocumentProcessor()

            # Fetch changed items using delta (incremental sync)
            delta_link = getattr(source, "delta_link", None)
            changes, new_delta_link = await client.get_drive_items_delta(
                drive_id=source.drive_id,
                delta_link=delta_link,
            )

            logger.info(
                "Source %s: %d changed items to process (delta=%s)",
                source_id, len(changes), "incremental" if delta_link else "initial",
            )

            for item in changes:
                stats_before = dict(stats)
                try:
                    # One savepoint per item: a DB error while handling this
                    # item rolls back only this item. Without it the session
                    # would refuse every later flush and the final commit.
                    async with session.begin_nested():
                        await _process_delta_item(
                            session=session,
                            client=client,
                            processor=processor,
                            item=item,
                            source=source,
                            stats=stats,
                        )
                except Exception as exc:
                    # Undo any counter the failed item bumped before it failed.
                    stats.update(stats_before)
                    logger.error("Failed to process item %s: %s", item.get("id"), exc)
                    stats["failed"] += 1

            # Persist delta link for next incremental sync
            if new_delta_link:
                source.delta_link = new_delta_link

            completed_at = datetime.utcnow()
            job.status = JobStatus.COMPLETED
            job.documents_added = stats["added"]
            job.documents_updated = stats["updated"]
            job.documents_deleted = stats["deleted"]
            job.documents_failed = stats["failed"]
            job.completed_at = completed_at
            job.duration_seconds = int((completed_at - started_at).total_seconds())

            source.last_sync_at = completed_at
            source.last_sync_status = SyncStatus.SUCCESS
            source.error_message = None
            await session.commit()

            logger.info(
                "Sync complete for source %s: +%d ~%d -%d ✗%d",
                source_id, stats["added"], stats["updated"], stats["deleted"], stats["failed"],
            )
            return stats

        except Exception as exc:
            logger.exception("Sync failed for source %s: %s", source_id, exc)
            await _record_sync_failure(session, job, source, exc, started_at, source_id)
            raise  # Let Celery handle retry


async def _record_sync_failure(
    session: AsyncSession,
    job: SyncJob,
    source: SharePointSource,
    exc: BaseException,
    started_at: datetime,
    source_id: str,
) -> None:
    """
    Best-effort: mark ``job`` FAILED and ``source`` ERROR after a sync error.

    Rolls back first. If the error came from a flush or commit, committing on
    the same session would raise again, mask the original exception and leave
    the job RUNNING. Errors here are logged, not raised, so the caller can
    re-raise the original exception.
    """
    try:
        await session.rollback()
        completed_at = datetime.utcnow()
        job.status = JobStatus.FAILED
        job.error_message = str(exc)
        job.completed_at = completed_at
        job.duration_seconds = int((completed_at - started_at).total_seconds())
        source.last_sync_status = SyncStatus.ERROR
        source.error_message = str(exc)
        await session.commit()
    except Exception:
        logger.exception("Could not record sync failure for source %s", source_id)
|
|
|
|
|
|
async def _process_delta_item(
    session: AsyncSession,
    client: SharePointGraphClient,
    processor: DocumentProcessor,
    item: Dict,
    source: SharePointSource,
    stats: Dict[str, int],
) -> None:
    """
    Process a single changed item from the delta feed.

    Cases:
      - item has a ``deleted`` facet: if known in the DB, mark it inactive
        and call ``processor.deactivate_document``;
      - item has no ``file`` facet (folders, the drive root): ignored;
      - unsupported extension, or larger than MAX_FILE_SIZE_MB_SHAREPOINT: ignored;
      - otherwise: download, (re-)index through ``processor.process_document``,
        then update the existing DB record or insert a new one.

    Increments ``stats["deleted"]``, ``stats["updated"]`` or ``stats["added"]``.
    Download and indexing errors propagate to the caller.
    """
    # Function-scope import, as in the original (presumably avoids an import cycle).
    from app.config import settings as app_settings

    item_id = item["id"]

    if "deleted" in item:
        existing_doc = await _get_document_by_sharepoint_id(session, item_id)
        if existing_doc:
            existing_doc.is_active = False
            # NOTE(review): called without await — confirm deactivate_document is synchronous.
            processor.deactivate_document(item_id)
            await session.flush()
            stats["deleted"] += 1
            logger.debug("Deleted document %s (%s)", item.get("name"), item_id)
        return

    file_name = item.get("name", "")

    # The delta feed also returns folders and the root item. Only items that
    # carry a "file" facet have downloadable content.
    if "file" not in item:
        logger.debug("Skipping non-file item %s (%s)", file_name, item_id)
        return

    # Only process supported file types
    file_ext = file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""
    if file_ext not in app_settings.SUPPORTED_FILE_TYPES:
        logger.debug("Skipping unsupported file type '%s': %s", file_ext, file_name)
        return

    file_size = item.get("size", 0)
    if file_size > app_settings.MAX_FILE_SIZE_MB_SHAREPOINT * 1024 * 1024:
        logger.warning("Skipping oversized file %s (%d bytes)", file_name, file_size)
        return

    # Download file
    file_bytes = await client.download_file(source.drive_id, item_id)

    # Determine metadata
    last_modified_str = item.get("lastModifiedDateTime", "")
    last_modified = (
        datetime.fromisoformat(last_modified_str.rstrip("Z"))
        if last_modified_str
        else datetime.utcnow()
    )
    author = item.get("createdBy", {}).get("user", {}).get("displayName")
    sharepoint_url = item.get("webUrl", "")
    file_path = item.get("parentReference", {}).get("path", "") + "/" + file_name

    # Embed and index in Qdrant
    vector_count = await processor.process_document(
        file_bytes=file_bytes,
        file_name=file_name,
        file_type=file_ext,
        sharepoint_id=item_id,
        file_url=sharepoint_url,
        source_id=str(source.id),
        department_id=str(source.department_id) if source.department_id else None,
        region_code=source.region_code,
        last_modified=last_modified,
        author=author,
    )

    # Persist or update DB record
    existing_doc = await _get_document_by_sharepoint_id(session, item_id)
    if existing_doc:
        existing_doc.file_name = file_name
        # Path and URL change on rename/move; refresh them together with the name.
        existing_doc.file_path = file_path
        existing_doc.sharepoint_url = sharepoint_url
        existing_doc.file_size = file_size
        existing_doc.last_modified = last_modified
        existing_doc.vector_count = vector_count
        existing_doc.last_indexed_at = datetime.utcnow()
        existing_doc.is_active = True
        stats["updated"] += 1
    else:
        session.add(SharePointDocument(
            source_id=source.id,
            sharepoint_id=item_id,
            sharepoint_url=sharepoint_url,
            file_name=file_name,
            file_path=file_path,
            file_type=file_ext,
            file_size=file_size,
            last_modified=last_modified,
            author=author,
            department_id=source.department_id,
            region_code=source.region_code,
            vector_count=vector_count,
            last_indexed_at=datetime.utcnow(),
        ))
        stats["added"] += 1

    await session.flush()
|
|
|
|
|
|
# =============================================================================
|
|
# Task: Process Single Document
|
|
# =============================================================================
|
|
|
|
@celery_app.task(
    name="app.tasks.sharepoint_sync.process_single_document",
    bind=True,
    max_retries=3,
    default_retry_delay=30,
)
def process_single_document(
    self,
    drive_id: str,
    item_id: str,
    source_id: str,
) -> Dict[str, Any]:
    """
    Celery entry point: download and index one SharePoint file.

    Dispatched by the webhook handler when a notification names a specific item.
    NOTE(review): the task never calls ``self.retry``, so ``max_retries`` and
    ``default_retry_delay`` have no effect as written.

    Args:
        drive_id: Graph drive ID holding the item.
        item_id: Graph item ID of the file.
        source_id: UUID string of the owning SharePointSource.

    Returns:
        ``{"item_id", "file_name", "vector_count"}`` on success, or a dict
        with an ``"error"`` key.
    """
    work = _async_process_single_document(drive_id, item_id, source_id)
    return asyncio.run(work)
|
|
|
|
|
|
async def _async_process_single_document(
    drive_id: str, item_id: str, source_id: str
) -> Dict[str, Any]:
    """
    Async implementation of process_single_document.

    Applies the same eligibility rules as the delta-sync path before
    downloading: the source must be active, the item must not be a folder,
    the extension must be in SUPPORTED_FILE_TYPES and the size within
    MAX_FILE_SIZE_MB_SHAREPOINT. Then it indexes the file and upserts its
    SharePointDocument.

    Returns:
        ``{"item_id", "file_name", "vector_count"}`` on success;
        ``{"skipped": reason, ...}`` when the item is not eligible;
        ``{"error": ...}`` on a missing source or a Graph / processing error.
    """
    # Function-scope import, matching the delta-sync path.
    from app.config import settings as app_settings

    async with AsyncSessionLocal() as session:
        source = await _get_source(session, source_id)
        if not source:
            return {"error": "source_not_found"}

        # The sync task ignores inactive sources; webhook-triggered work must as well.
        if not source.is_active:
            logger.info("Source %s is inactive, skipping item %s", source_id, item_id)
            return {"skipped": "inactive", "item_id": item_id}

        client = SharePointGraphClient()
        processor = DocumentProcessor()

        try:
            metadata = await client.get_file_metadata(drive_id, item_id)
            file_name = metadata.get("name", "")
            file_ext = file_name.rsplit(".", 1)[-1].lower() if "." in file_name else ""

            if "folder" in metadata:
                logger.debug("Skipping folder %s (%s)", file_name, item_id)
                return {"skipped": "folder", "item_id": item_id}

            if file_ext not in app_settings.SUPPORTED_FILE_TYPES:
                logger.debug("Skipping unsupported file type '%s': %s", file_ext, file_name)
                return {"skipped": "unsupported_file_type", "item_id": item_id}

            file_size = metadata.get("size", 0)
            if file_size > app_settings.MAX_FILE_SIZE_MB_SHAREPOINT * 1024 * 1024:
                logger.warning("Skipping oversized file %s (%d bytes)", file_name, file_size)
                return {"skipped": "oversized", "item_id": item_id}

            file_bytes = await client.download_file(drive_id, item_id)

            last_modified_str = metadata.get("lastModifiedDateTime", "")
            last_modified = (
                datetime.fromisoformat(last_modified_str.rstrip("Z"))
                if last_modified_str
                else datetime.utcnow()
            )
            author = metadata.get("createdBy", {}).get("user", {}).get("displayName")
            sharepoint_url = metadata.get("webUrl", "")
            file_path = metadata.get("parentReference", {}).get("path", "") + "/" + file_name

            vector_count = await processor.process_document(
                file_bytes=file_bytes,
                file_name=file_name,
                file_type=file_ext,
                sharepoint_id=item_id,
                file_url=sharepoint_url,
                source_id=source_id,
                department_id=str(source.department_id) if source.department_id else None,
                region_code=source.region_code,
                last_modified=last_modified,
                author=author,
            )

            # Upsert DB record, refreshing the same fields the delta-sync path does.
            existing = await _get_document_by_sharepoint_id(session, item_id)
            if existing:
                existing.file_name = file_name
                existing.file_path = file_path
                existing.sharepoint_url = sharepoint_url
                existing.file_size = file_size
                existing.last_modified = last_modified
                existing.vector_count = vector_count
                existing.last_indexed_at = datetime.utcnow()
                existing.is_active = True
            else:
                session.add(SharePointDocument(
                    source_id=source.id,
                    sharepoint_id=item_id,
                    sharepoint_url=sharepoint_url,
                    file_name=file_name,
                    file_path=file_path,
                    file_type=file_ext,
                    file_size=file_size,
                    last_modified=last_modified,
                    author=author,
                    department_id=source.department_id,
                    region_code=source.region_code,
                    vector_count=vector_count,
                    last_indexed_at=datetime.utcnow(),
                ))

            await session.commit()
            return {"item_id": item_id, "file_name": file_name, "vector_count": vector_count}

        except (GraphAPIError, DocumentProcessingError) as exc:
            # NOTE(review): an item deleted in SharePoint presumably lands here
            # (metadata lookup fails); it is deactivated only by the delta sync.
            logger.error("Failed to process %s: %s", item_id, exc)
            return {"error": str(exc)}
|
|
|
|
|
|
# =============================================================================
|
|
# Task: Handle Webhook Notification
|
|
# =============================================================================
|
|
|
|
@celery_app.task(
    name="app.tasks.sharepoint_sync.process_webhook_notification",
    bind=True,
    max_retries=2,
    default_retry_delay=10,
)
def process_webhook_notification(self, notification_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Celery entry point for one Microsoft Graph change notification.

    Looks up the active webhook by ``subscriptionId``, checks ``clientState``,
    then either queues process_single_document (when ``resourceData`` carries
    an item id) or queues an incremental sync of the whole source.

    Args:
        notification_data: One notification object from the webhook payload.
            Keys read: ``subscriptionId``, ``clientState``, ``resourceData``
            (plus ``resource`` / ``changeType`` as sent by Graph).

    Returns:
        A dict describing what was dispatched, or an ``"error"`` key.
    """
    pending = _async_process_webhook_notification(notification_data)
    return asyncio.run(pending)
|
|
|
|
|
|
async def _async_process_webhook_notification(notification_data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Async implementation of process_webhook_notification.

    ``notification_data`` arrives from an external HTTP caller, so every field
    is treated as untrusted: the clientState is compared in constant time and
    ``resourceData`` may be missing, null, or not an object.

    Returns:
        ``{"dispatched": 1, "item_id": ...}`` when a single-item task was queued;
        ``{"dispatched": 1, "mode": "incremental_sync"}`` when a source sync was queued;
        ``{"error": ...}`` for unknown webhooks, bad client state or missing sources.
    """
    import hmac  # constant-time comparison for the shared secret

    subscription_id = notification_data.get("subscriptionId", "")

    async with AsyncSessionLocal() as session:
        # Find the webhook subscription
        result = await session.execute(
            select(SharePointWebhook).where(
                SharePointWebhook.subscription_id == subscription_id,
                SharePointWebhook.is_active == True,  # noqa: E712 (SQLAlchemy expression)
            )
        )
        webhook = result.scalar_one_or_none()

        if not webhook:
            logger.warning("No active webhook found for subscription_id=%s", subscription_id)
            return {"error": "webhook_not_found"}

        # Validate client state to prevent spoofed notifications. Compare as
        # bytes so non-ASCII or non-string input cannot make compare_digest raise.
        expected_state = webhook.client_state
        received_state = notification_data.get("clientState") or ""
        if expected_state and not hmac.compare_digest(
            str(received_state).encode("utf-8"),
            str(expected_state).encode("utf-8"),
        ):
            logger.warning("Webhook client state mismatch for subscription %s", subscription_id)
            return {"error": "invalid_client_state"}

        source_id = str(webhook.source_id)

        # Load source to get drive_id
        source = await _get_source(session, source_id)
        if not source:
            return {"error": "source_not_found"}

        # resourceData may be absent or JSON null; only a dict can carry an item id.
        resource_data = notification_data.get("resourceData")
        if not isinstance(resource_data, dict):
            resource_data = {}
        item_id = resource_data.get("id")
        change_types = str(notification_data.get("changeType") or "").split(",")

        if item_id and "deleted" not in change_types:
            # Single item notification
            process_single_document.delay(source.drive_id, item_id, source_id)
            return {"dispatched": 1, "item_id": item_id}

        # Drive-level notifications and deletions go through the delta feed,
        # which is the path in this module that deactivates removed documents.
        sync_sharepoint_source.delay(source_id)
        return {"dispatched": 1, "mode": "incremental_sync"}
|
|
|
|
|
|
# =============================================================================
|
|
# Task: Renew Expiring Webhooks
|
|
# =============================================================================
|
|
|
|
@celery_app.task(name="app.tasks.sharepoint_sync.renew_expiring_webhooks")
def renew_expiring_webhooks() -> Dict[str, Any]:
    """
    Celery entry point: renew active Graph webhook subscriptions close to expiry.

    Graph subscriptions have a maximum lifetime and stop delivering
    notifications once they lapse, so they must be renewed before expiry.
    Intended to be scheduled periodically from Celery Beat (the schedule is
    configured outside this module).

    Returns:
        ``{"renewed": count}``, plus an ``"errors"`` list when any renewal was attempted.
    """
    renewal = _async_renew_expiring_webhooks()
    return asyncio.run(renewal)
|
|
|
|
|
|
def _parse_graph_expiration(value: str) -> datetime:
    """
    Parse a Graph ISO-8601 timestamp into a naive UTC datetime.

    Graph may return up to 7 fractional-second digits (e.g. ``...45.9356913Z``).
    Before Python 3.11, ``datetime.fromisoformat`` accepts only 3 or 6 digits
    and no ``Z`` suffix, so the fraction is normalised to exactly 6 digits and
    ``Z`` becomes ``+00:00``.
    """
    import re
    from datetime import timezone

    text = value.strip()
    if text.endswith(("Z", "z")):
        text = text[:-1] + "+00:00"
    text = re.sub(
        r"\.(\d+)",
        lambda m: "." + m.group(1)[:6].ljust(6, "0"),
        text,
        count=1,
    )
    parsed = datetime.fromisoformat(text)
    if parsed.tzinfo is not None:
        # The DB columns in this module store naive UTC (datetime.utcnow()).
        parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
    return parsed


async def _async_renew_expiring_webhooks(
    renewal_window_days: int = 7,
    expiration_days: int = 29,
) -> Dict[str, Any]:
    """
    Async implementation of renew_expiring_webhooks.

    Args:
        renewal_window_days: renew active webhooks whose stored expiration is
            within this many days from now (already-expired ones included).
        expiration_days: lifetime requested from Graph on renewal. Defaults
            to 29 because Graph caps driveItem/list subscriptions at 42,300
            minutes (under 30 days); a 30-day request exceeds that cap.

    Returns:
        ``{"renewed": 0}`` when nothing is due, otherwise
        ``{"renewed": n, "errors": [{"subscription_id", "error"}, ...]}``.
    """
    renewal_threshold = datetime.utcnow() + timedelta(days=renewal_window_days)

    async with AsyncSessionLocal() as session:
        result = await session.execute(
            select(SharePointWebhook).where(
                SharePointWebhook.is_active == True,  # noqa: E712 (SQLAlchemy expression)
                SharePointWebhook.expiration_datetime <= renewal_threshold,
            )
        )
        # Copy plain values out so nothing below reads ORM instances after
        # this session is closed.
        expiring = [(wh.id, str(wh.subscription_id)) for wh in result.scalars().all()]

    if not expiring:
        logger.info("No webhooks expiring within %d days", renewal_window_days)
        return {"renewed": 0}

    client = SharePointGraphClient()
    renewed = 0
    errors: List[Dict[str, str]] = []

    for webhook_id, subscription_id in expiring:
        try:
            updated = await client.renew_webhook(subscription_id, expiration_days=expiration_days)
            new_expiry = _parse_graph_expiration(updated["expirationDateTime"])
            async with AsyncSessionLocal() as session:
                wh = await session.get(SharePointWebhook, webhook_id)
                if wh:
                    wh.expiration_datetime = new_expiry
                    wh.last_renewed_at = datetime.utcnow()
                    await session.commit()
            renewed += 1
            logger.info("Renewed webhook %s", subscription_id)
        except Exception as exc:
            logger.error("Failed to renew webhook %s: %s", subscription_id, exc)
            errors.append({"subscription_id": subscription_id, "error": str(exc)})

    return {"renewed": renewed, "errors": errors}
|
|
|
|
|
|
# =============================================================================
|
|
# Task: Sync All Active Sources (Hourly Beat)
|
|
# =============================================================================
|
|
|
|
@celery_app.task(name="app.tasks.sharepoint_sync.sync_all_active_sources")
def sync_all_active_sources() -> Dict[str, Any]:
    """
    Celery Beat entry point (hourly): queue a sync for every active SharePoint source.

    Returns:
        ``{"dispatched": n}`` where ``n`` is the number of sync tasks queued.
    """
    dispatch = _async_sync_all_active_sources()
    return asyncio.run(dispatch)
|
|
|
|
|
|
async def _async_sync_all_active_sources() -> Dict[str, Any]:
    """Queue sync_sharepoint_source for each active source and report how many were queued."""
    async with AsyncSessionLocal() as session:
        rows = await session.execute(
            select(SharePointSource).where(SharePointSource.is_active == True)  # noqa: E712
        )
        active_ids = [str(src.id) for src in rows.scalars().all()]

    for sid in active_ids:
        sync_sharepoint_source.delay(sid)
        logger.info("Dispatched sync for source %s", sid)

    return {"dispatched": len(active_ids)}
|
|
|
|
|
|
# =============================================================================
|
|
# DB helpers
|
|
# =============================================================================
|
|
|
|
async def _get_source(session: AsyncSession, source_id: str) -> Optional[SharePointSource]:
    """
    Return the SharePointSource whose primary key is ``source_id``, or None.

    Raises:
        ValueError: if ``source_id`` is not a valid UUID string (from ``UUID()``).
    """
    source_uuid = UUID(source_id)
    stmt = select(SharePointSource).where(SharePointSource.id == source_uuid)
    rows = await session.execute(stmt)
    return rows.scalar_one_or_none()
|
|
|
|
|
|
async def _get_document_by_sharepoint_id(
    session: AsyncSession, sharepoint_id: str
) -> Optional[SharePointDocument]:
    """
    Return the SharePointDocument with the given Graph item id, or None.

    NOTE(review): the lookup is not scoped to a source. ``scalar_one_or_none``
    raises if several rows share the same ``sharepoint_id``.
    """
    stmt = select(SharePointDocument).where(
        SharePointDocument.sharepoint_id == sharepoint_id
    )
    rows = await session.execute(stmt)
    return rows.scalar_one_or_none()
|