Backend:
- task_manager.py: add result/error/completed_at storage, TTL sweeper (5min), store_task_result() helper
- tasks.py: add GET /<task_id> endpoint returning stored result; cancel route stores 'cancelled' status
- __init__.py: start TTL sweeper on app startup
- All 8 bg functions: store result before emitting lightweight WS hint (no payload data)
Frontend:
- src/lib/taskPolling.ts: waitForTaskResult() — polls GET /tasks/{id} every 2s, WS hint triggers immediate poll, 5min timeout
- src/hooks/useTaskPolling.ts: drop-in replacement for useCancellableGeneration using polling
- Migrate 6 Promise-based WS listeners → waitForTaskResult() in DiscussionPanel, FocusGroupSession (×2), PersonaProfile, PersonaModificationModal, useDiscussionGuideGeneration
- Migrate 3 hook-based consumers → useTaskPolling in AIRecruiter, SyntheticUsers, BulkExportProgressModal
Fixes WS Promise leak: polling survives disconnects, background tabs, page reloads.
WS events retained as zero-payload hints for near-zero latency when connected.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
280 lines
No EOL
9.7 KiB
Python
Executable file
280 lines
No EOL
9.7 KiB
Python
Executable file
"""
|
|
Task Manager Service for handling cancellable long-running operations.
|
|
|
|
This service provides a centralized way to track and cancel asyncio tasks
|
|
across all generation processes in the application.
|
|
"""
|
|
|
|
import asyncio
|
|
import uuid
|
|
from typing import Dict, Optional, Any
|
|
from datetime import datetime, timezone
|
|
import logging
|
|
|
|
# Module-scoped logger named after this module, so its output can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class TaskInfo:
    """Bookkeeping record for one tracked task and, once finished, its result."""

    def __init__(
        self,
        task_id: str,
        task: asyncio.Task,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            task_id: Identifier used to look the task up.
            task: The asyncio task being tracked.
            task_type: Free-form category label (e.g. 'persona_generation').
            user_id: ID of the user who initiated the task, if known.
            metadata: Extra data about the task; a falsy value becomes ``{}``.
        """
        self.task_id = task_id
        self.task = task
        self.task_type = task_type
        self.user_id = user_id
        self.metadata = metadata or {}
        self.created_at = datetime.now(timezone.utc)
        # "running" until a terminal status is recorded.
        self.status = "running"
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        # Set when a terminal status is recorded; the TTL countdown starts here.
        self.completed_at: Optional[datetime] = None

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(task_id={self.task_id!r}, "
            f"task_type={self.task_type!r}, status={self.status!r}, "
            f"user_id={self.user_id!r})"
        )
|
|
|
|
|
|
class TaskManager:
    """
    Singleton service for tracking cancellable asyncio tasks and their results.

    Entry lifecycle:
      1. ``register_task`` stores a TaskInfo with status ``"running"``.
      2. A background function calls ``store_result`` (or a client calls
         ``cancel_task``), which sets a terminal status and ``completed_at``.
      3. If the asyncio task finishes before step 2 happened, a done-callback
         records its outcome, so no entry is left "running" forever.
      4. ``sweep_expired_tasks`` (looped by ``start_sweeper``) deletes entries
         whose ``completed_at`` is older than ``RESULT_TTL_SECONDS``.

    Coroutine access to ``_tasks`` is serialised by ``_task_lock``. No locked
    section below awaits while holding the lock, which is what lets the
    synchronous done-callback read and update ``_tasks`` safely on the loop.
    """

    # How long a finished task's result stays available for polling.
    RESULT_TTL_SECONDS = 300  # 5 minutes

    _instance = None
    # NOTE(review): not referenced by this class; the instance-level
    # ``_task_lock`` is the lock actually used. Kept in case other code uses it.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Every TaskManager() call returns the same object.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every TaskManager() call; only initialise once.
        if not getattr(self, '_initialized', False):
            self._tasks: Dict[str, TaskInfo] = {}
            self._task_lock = asyncio.Lock()
            self._initialized = True

    def generate_task_id(self) -> str:
        """Generate a unique task ID (random UUID4 string)."""
        return str(uuid.uuid4())

    async def register_task(
        self,
        task: asyncio.Task,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """
        Register a new task for tracking and potential cancellation.

        Args:
            task: The asyncio task to track
            task_type: Type of task (e.g., 'persona_generation', 'discussion_guide')
            user_id: ID of the user who initiated the task
            metadata: Additional metadata about the task
            task_id: Optional custom task ID (will generate if not provided);
                an existing entry with the same ID is replaced.

        Returns:
            The task ID for tracking
        """
        if task_id is None:
            task_id = self.generate_task_id()

        async with self._task_lock:
            self._tasks[task_id] = TaskInfo(task_id, task, task_type, user_id, metadata)

        # Synchronous callback: no fire-and-forget task is created, and the
        # entry is finalised even if the task ends without storing a result.
        task.add_done_callback(lambda finished: self._on_task_done(task_id, finished))

        logger.info("Registered task %s of type %s for user %s", task_id, task_type, user_id)
        return task_id

    def _on_task_done(self, task_id: str, task: asyncio.Task) -> None:
        """
        Done-callback for registered tasks.

        If no terminal state was recorded for the entry (``completed_at`` is
        still None), derive one from the asyncio task's outcome. The entry then
        stops reporting "running" and becomes eligible for the TTL sweep.
        """
        info = self._tasks.get(task_id)
        # Skip if the entry was swept/cleared, replaced by a re-registration
        # under the same ID, or already finalised by store_result/cancel_task.
        if info is None or info.task is not task or info.completed_at is not None:
            return

        if task.cancelled():
            info.status = "cancelled"
        else:
            exc = task.exception()  # safe: task is done and not cancelled
            if exc is not None:
                info.status = "failed"
                info.error = str(exc) or type(exc).__name__
                logger.error("Task %s raised without storing a result", task_id, exc_info=exc)
            else:
                info.status = "completed"
        info.completed_at = datetime.now(timezone.utc)
        logger.info(
            "Task %s finished without a stored result; recorded status=%s",
            task_id, info.status,
        )

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task by its ID.

        Args:
            task_id: The ID of the task to cancel

        Returns:
            True if task was found and cancelled, False otherwise (unknown ID,
            asyncio task already done, or a terminal result already stored).
        """
        async with self._task_lock:
            task_info = self._tasks.get(task_id)
            if not task_info:
                logger.warning("Task %s not found for cancellation", task_id)
                return False

            # A stored terminal result means the work is finished; cancelling
            # now would only overwrite that result with "cancelled".
            if task_info.task.done() or task_info.completed_at is not None:
                logger.info("Task %s already completed", task_id)
                return False

            task_info.task.cancel()
            task_info.status = "cancelled"
            task_info.completed_at = datetime.now(timezone.utc)

        logger.info("Cancelled task %s of type %s", task_id, task_info.task_type)
        return True

    async def get_task_info(self, task_id: str) -> Optional[TaskInfo]:
        """Return the TaskInfo for ``task_id``, or None if it is not tracked."""
        async with self._task_lock:
            return self._tasks.get(task_id)

    async def get_user_tasks(self, user_id: str) -> Dict[str, TaskInfo]:
        """Return ``{task_id: TaskInfo}`` for the user's tasks whose asyncio task is not done."""
        async with self._task_lock:
            return {
                task_id: task_info
                for task_id, task_info in self._tasks.items()
                if task_info.user_id == user_id and not task_info.task.done()
            }

    async def store_result(
        self,
        task_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ):
        """
        Store a task's terminal status and result for polling.

        Called by background functions. ``status`` is expected to be one of
        'completed', 'failed', 'cancelled'. Unknown IDs (never registered or
        already swept) are logged and ignored.
        """
        async with self._task_lock:
            task_info = self._tasks.get(task_id)
            if task_info is None:
                logger.warning("Cannot store result for unknown task %s (status=%s)", task_id, status)
                return
            task_info.status = status
            task_info.result = result
            task_info.error = error
            task_info.completed_at = datetime.now(timezone.utc)
        logger.info("Stored result for task %s: status=%s", task_id, status)

    async def get_task_status_dict(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return a serialisable status/result snapshot for the polling endpoint, or None."""
        async with self._task_lock:
            info = self._tasks.get(task_id)
            if not info:
                return None
            return {
                'task_id': info.task_id,
                'status': info.status,
                'task_type': info.task_type,
                'result': info.result,
                'error': info.error,
                'created_at': info.created_at.isoformat(),
                'completed_at': info.completed_at.isoformat() if info.completed_at else None,
            }

    async def _cleanup_completed_task(self, task_id: str):
        """Log-only hook, no longer scheduled; kept so existing references keep working."""
        logger.info("Task %s asyncio Task finished — result retained for polling TTL", task_id)

    async def sweep_expired_tasks(self):
        """Delete entries whose ``completed_at`` is more than RESULT_TTL_SECONDS in the past."""
        now = datetime.now(timezone.utc)
        async with self._task_lock:
            expired = [
                tid for tid, info in self._tasks.items()
                if info.completed_at is not None
                and (now - info.completed_at).total_seconds() > self.RESULT_TTL_SECONDS
            ]
            for tid in expired:
                del self._tasks[tid]
        if expired:
            logger.info("Swept %d expired task results", len(expired))

    async def start_sweeper(self, interval_seconds: float = 60):
        """
        Run ``sweep_expired_tasks`` every ``interval_seconds`` until cancelled.

        A failing sweep is logged and retried on the next tick. Without this,
        one exception would end the loop and finished results would then
        accumulate forever.
        """
        while True:
            await asyncio.sleep(interval_seconds)
            try:
                await self.sweep_expired_tasks()
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("Task result sweep failed; retrying in %s seconds", interval_seconds)

    async def get_active_task_count(self) -> int:
        """Return how many tracked tasks have an asyncio task that is not yet done."""
        async with self._task_lock:
            return sum(1 for info in self._tasks.values() if not info.task.done())

    async def cleanup_all_tasks(self):
        """Cancel every unfinished task and forget all entries (testing/shutdown)."""
        async with self._task_lock:
            for task_info in self._tasks.values():
                if not task_info.task.done():
                    task_info.task.cancel()
            self._tasks.clear()
        logger.info("All tasks cleaned up")
|
|
|
|
|
|
# Process-wide shared manager. TaskManager.__new__ returns the same object on
# every call, so this name and any later TaskManager() call share one instance.
task_manager: TaskManager = TaskManager()
|
|
|
|
|
|
def get_task_manager() -> TaskManager:
    """
    Accessor for the shared TaskManager.

    Returns:
        The module-level ``task_manager`` instance.
    """
    shared = task_manager
    return shared
|
|
|
|
|
|
async def register_cancellable_task(
    task: asyncio.Task,
    task_type: str,
    user_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    task_id: Optional[str] = None,
) -> str:
    """
    Convenience function to register a task with the global task manager.

    Args:
        task: The asyncio task to track.
        task_type: Category label for the task.
        user_id: ID of the initiating user, if known.
        metadata: Extra data stored with the task.
        task_id: Optional caller-chosen ID; a new one is generated when omitted.

    Returns:
        The task ID for tracking
    """
    return await get_task_manager().register_task(task, task_type, user_id, metadata, task_id)
|
|
|
|
|
|
async def cancel_task_by_id(task_id: str) -> bool:
    """
    Cancel a tracked task through the global task manager.

    Args:
        task_id: ID returned when the task was registered.

    Returns:
        True if task was found and cancelled, False otherwise
    """
    manager = get_task_manager()
    was_cancelled = await manager.cancel_task(task_id)
    return was_cancelled
|
|
|
|
|
|
async def store_task_result(task_id: str, status: str, result: Dict[str, Any] = None, error: str = None):
    """
    Record a task's status plus optional result/error for polling clients.

    Thin wrapper around ``store_result`` on the global task manager.
    """
    manager = get_task_manager()
    await manager.store_result(task_id=task_id, status=status, result=result, error=error)
|
|
|
|
|
|
class CancellableTask:
    """
    Async context manager that registers the *current* asyncio task with the
    global task manager so it can be cancelled by ID.

    Usage:
        async with CancellableTask("persona_generation", user_id="123") as task_id:
            # Your long-running operation here
            await some_async_operation()

    ``task_id`` is None when there is no current asyncio task.

    On exit, if nothing inside the block stored a terminal result for the task
    (``completed_at`` still unset), the block's outcome is recorded:
    'completed' on normal exit, 'cancelled' on CancelledError, 'failed'
    otherwise. This gives the entry a completion time, so the task manager's
    TTL sweep can remove it. Exceptions are never suppressed.
    """

    def __init__(
        self,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.task_type = task_type
        self.user_id = user_id
        self.metadata = metadata
        self.task_id: Optional[str] = None

    async def __aenter__(self) -> Optional[str]:
        # The task being registered is the one running this `async with`.
        current_task = asyncio.current_task()
        if current_task:
            self.task_id = await register_cancellable_task(
                current_task, self.task_type, self.user_id, self.metadata
            )
        return self.task_id

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        if self.task_id is None:
            return
        info = await get_task_manager().get_task_info(self.task_id)
        # Don't clobber a result the block stored itself (or a recorded cancel).
        if info is None or info.completed_at is not None:
            return
        if exc_type is None:
            await store_task_result(self.task_id, "completed")
        elif issubclass(exc_type, asyncio.CancelledError):
            await store_task_result(self.task_id, "cancelled")
        else:
            await store_task_result(
                self.task_id, "failed", error=str(exc_val) or exc_type.__name__
            )
        # Falls through returning None, so any exception propagates.
|
|
|
|
|
|
def check_cancellation():
    """
    Decorator factory adding a cancellation checkpoint before an async function.

    Usage::

        @check_cancellation()
        async def step(...):
            ...

    Before calling the wrapped coroutine function, the wrapper yields to the
    event loop once via ``await asyncio.sleep(0)``. If ``cancel()`` has been
    requested for the current task, asyncio raises ``CancelledError`` at that
    yield and the wrapped function is not started. The yield also gives other
    ready tasks a chance to run.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # Not `asyncio.current_task().cancelled()`: that is only True once
            # a task has *finished* in the cancelled state, so it is never True
            # for the task that is executing this line.
            await asyncio.sleep(0)
            return await func(*args, **kwargs)

        return wrapper

    return decorator