semblance-dev/backend/app/services/task_manager.py
Vadym Samoilenko 1b387daacf Migrate task result delivery from WebSocket to HTTP polling
Backend:
- task_manager.py: add result/error/completed_at storage, TTL sweeper (5min), store_task_result() helper
- tasks.py: add GET /<task_id> endpoint returning stored result; cancel route stores 'cancelled' status
- __init__.py: start TTL sweeper on app startup
- All 8 bg functions: store result before emitting lightweight WS hint (no payload data)

Frontend:
- src/lib/taskPolling.ts: waitForTaskResult() — polls GET /tasks/{id} every 2s, WS hint triggers immediate poll, 5min timeout
- src/hooks/useTaskPolling.ts: drop-in replacement for useCancellableGeneration using polling
- Migrate 6 Promise-based WS listeners → waitForTaskResult() in DiscussionPanel, FocusGroupSession (×2), PersonaProfile, PersonaModificationModal, useDiscussionGuideGeneration
- Migrate 3 hook-based consumers → useTaskPolling in AIRecruiter, SyntheticUsers, BulkExportProgressModal

Fixes WS Promise leak: polling survives disconnects, background tabs, page reloads.
WS events retained as zero-payload hints for near-zero latency when connected.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-23 16:46:58 +00:00

280 lines
No EOL
9.7 KiB
Python
Executable file

"""
Task Manager Service for handling cancellable long-running operations.
This service provides a centralized way to track and cancel asyncio tasks
across all generation processes in the application.
"""
import asyncio
import functools
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional
logger = logging.getLogger(__name__)
class TaskInfo:
    """Information about a tracked task.

    Attributes:
        task_id: ID the task is registered under.
        task: The asyncio task being tracked.
        task_type: Category label, e.g. 'persona_generation'.
        user_id: ID of the user who initiated the task, if known.
        metadata: Extra caller-supplied data (a fresh dict when none is given).
        created_at: Timezone-aware UTC timestamp of when this record was created.
        status: Starts as 'running'; the task manager later overwrites it
            (e.g. 'completed', 'failed', 'cancelled').
        result: Result payload stored for polling, if any.
        error: Error message stored for polling, if any.
        completed_at: UTC timestamp of when a final status was recorded, if any.
    """

    def __init__(
        self,
        task_id: str,
        task: asyncio.Task,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.task_id = task_id
        self.task = task
        self.task_type = task_type
        self.user_id = user_id
        # `or {}` also replaces an empty dict passed by the caller with a new one.
        self.metadata = metadata or {}
        self.created_at = datetime.now(timezone.utc)
        self.status = "running"
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        self.completed_at: Optional[datetime] = None

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(task_id={self.task_id!r}, "
            f"task_type={self.task_type!r}, user_id={self.user_id!r}, "
            f"status={self.status!r})"
        )
class TaskManager:
    """Singleton service for managing cancellable tasks.

    Tracks registered asyncio tasks by ID so they can be cancelled, and keeps
    each task's final status/result/error so HTTP polling clients can read it
    after the task has finished. Finished entries (those with ``completed_at``
    set) are removed by ``sweep_expired_tasks`` once they are older than
    ``RESULT_TTL_SECONDS``; ``start_sweeper`` runs that sweep periodically.
    """

    # Seconds a finished task's result is retained for polling.
    RESULT_TTL_SECONDS = 300  # 5 minutes

    _instance = None
    # NOTE(review): not referenced anywhere in this class (instances use
    # ``_task_lock``); kept in case other modules reference it.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Always hand back the same, lazily created instance.
        if cls._instance is None:
            cls._instance = super(TaskManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every TaskManager() call; set up state only once.
        if not getattr(self, '_initialized', False):
            self._tasks: Dict[str, TaskInfo] = {}
            self._task_lock = asyncio.Lock()
            self._initialized = True

    def generate_task_id(self) -> str:
        """Generate a unique task ID (a random UUID4 string)."""
        return str(uuid.uuid4())

    async def register_task(
        self,
        task: asyncio.Task,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """
        Register a new task for tracking and potential cancellation.

        Args:
            task: The asyncio task to track
            task_type: Type of task (e.g., 'persona_generation', 'discussion_guide')
            user_id: ID of the user who initiated the task
            metadata: Additional metadata about the task
            task_id: Optional custom task ID (will generate if not provided).
                Registering an ID that already exists replaces the old entry.

        Returns:
            The task ID for tracking
        """
        if task_id is None:
            task_id = self.generate_task_id()
        async with self._task_lock:
            task_info = TaskInfo(task_id, task, task_type, user_id, metadata)
            self._tasks[task_id] = task_info
            # Finalise this entry when the asyncio task finishes. Done callbacks
            # run synchronously on the event loop, so no fire-and-forget task is
            # needed. `task_info` is bound directly (not looked up by ID) so a
            # later re-registration of the same ID is not touched.
            task.add_done_callback(
                lambda finished: self._on_task_done(task_info, finished)
            )
        logger.info(f"Registered task {task_id} of type {task_type} for user {user_id}")
        return task_id

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task by its ID.

        Nothing is cancelled if the asyncio task is already done, or if a final
        status has already been recorded for it (for example via
        ``store_result``). Cancelling in that case would overwrite a stored
        status with 'cancelled'.

        Args:
            task_id: The ID of the task to cancel

        Returns:
            True if task was found and cancelled, False otherwise
        """
        async with self._task_lock:
            task_info = self._tasks.get(task_id)
            if not task_info:
                logger.warning(f"Task {task_id} not found for cancellation")
                return False
            if task_info.task.done() or task_info.status != "running":
                logger.info(f"Task {task_id} already completed")
                return False
            task_info.task.cancel()
            task_info.status = "cancelled"
            task_info.completed_at = datetime.now(timezone.utc)
            logger.info(f"Cancelled task {task_id} of type {task_info.task_type}")
            return True

    async def get_task_info(self, task_id: str) -> Optional[TaskInfo]:
        """Get information about a task by its ID (None if unknown or swept)."""
        async with self._task_lock:
            return self._tasks.get(task_id)

    async def get_user_tasks(self, user_id: str) -> Dict[str, TaskInfo]:
        """Get all tasks of ``user_id`` whose asyncio task is not yet done."""
        async with self._task_lock:
            return {
                task_id: task_info
                for task_id, task_info in self._tasks.items()
                if task_info.user_id == user_id and not task_info.task.done()
            }

    async def store_result(
        self,
        task_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ):
        """Store task result for polling. Called by background functions.

        Overwrites status/result/error and stamps ``completed_at``, which starts
        the retention TTL. Unknown IDs (never registered, or already swept) are
        logged and ignored.
        """
        async with self._task_lock:
            task_info = self._tasks.get(task_id)
            if task_info is None:
                logger.warning(f"Cannot store result for unknown task {task_id}: status={status}")
                return
            task_info.status = status  # 'completed', 'failed', 'cancelled'
            task_info.result = result
            task_info.error = error
            task_info.completed_at = datetime.now(timezone.utc)
            logger.info(f"Stored result for task {task_id}: status={status}")

    async def get_task_status_dict(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Get task status and result for the polling endpoint (None if unknown)."""
        async with self._task_lock:
            info = self._tasks.get(task_id)
            if not info:
                return None
            return {
                'task_id': info.task_id,
                'status': info.status,
                'task_type': info.task_type,
                'result': info.result,
                'error': info.error,
                'created_at': info.created_at.isoformat(),
                'completed_at': info.completed_at.isoformat() if info.completed_at else None,
            }

    def _on_task_done(self, task_info: TaskInfo, task: asyncio.Task) -> None:
        """Done-callback: give a finished task's entry a final status.

        Background functions normally call ``store_result`` before returning, so
        the status is already final and is left untouched. If the task finished
        without recording a status, the actual outcome is recorded here:
        'cancelled', 'failed' (with the exception message), or 'completed'.
        ``completed_at`` is always set, so the TTL sweeper can evict the entry.
        Without this, the entry would be kept forever with status 'running'.
        """
        if task_info.status == "running":
            # cancelled() must be checked first: exception() raises on a
            # cancelled task.
            if task.cancelled():
                task_info.status = "cancelled"
            else:
                exc = task.exception()
                if exc is not None:
                    task_info.status = "failed"
                    task_info.error = str(exc) or type(exc).__name__
                    # exception() marks the exception as retrieved, which stops
                    # asyncio from logging it later, so log it here.
                    logger.error(
                        f"Task {task_info.task_id} failed without storing a result",
                        exc_info=exc,
                    )
                else:
                    task_info.status = "completed"
        if task_info.completed_at is None:
            task_info.completed_at = datetime.now(timezone.utc)
        logger.info(f"Task {task_info.task_id} asyncio Task finished — result retained for polling TTL")

    async def _cleanup_completed_task(self, task_id: str):
        """Log that a task finished; results stay until the TTL sweeper removes them.

        No longer used by ``register_task`` (``_on_task_done`` handles completion);
        kept for backward compatibility.
        """
        logger.info(f"Task {task_id} asyncio Task finished — result retained for polling TTL")

    async def sweep_expired_tasks(self):
        """Remove tasks whose results have been retained past RESULT_TTL_SECONDS."""
        now = datetime.now(timezone.utc)
        async with self._task_lock:
            expired = [
                tid for tid, info in self._tasks.items()
                if info.completed_at and (now - info.completed_at).total_seconds() > self.RESULT_TTL_SECONDS
            ]
            for tid in expired:
                del self._tasks[tid]
        if expired:
            logger.info(f"Swept {len(expired)} expired task results")

    async def start_sweeper(self, interval_seconds: float = 60):
        """Background loop that runs ``sweep_expired_tasks`` every ``interval_seconds``.

        A failing sweep is logged and the loop continues. Cancelling the task
        that runs this coroutine stops it.
        """
        while True:
            await asyncio.sleep(interval_seconds)
            try:
                await self.sweep_expired_tasks()
            except Exception:
                # One bad sweep must not disable result cleanup for the
                # process lifetime.
                logger.exception("Task result sweep failed")

    async def get_active_task_count(self) -> int:
        """Get the number of tracked tasks whose asyncio task is not yet done."""
        async with self._task_lock:
            return sum(1 for t in self._tasks.values() if not t.task.done())

    async def cleanup_all_tasks(self):
        """Force cleanup of all tasks (useful for testing/shutdown).

        Cancels every unfinished task and drops all entries, including retained
        results.
        """
        async with self._task_lock:
            for task_info in self._tasks.values():
                if not task_info.task.done():
                    task_info.task.cancel()
            self._tasks.clear()
            logger.info("All tasks cleaned up")
# Module-level TaskManager shared by the convenience helpers in this module.
# TaskManager.__new__ caches its instance, so this is the same object any
# later TaskManager() call returns.
task_manager = TaskManager()


def get_task_manager() -> TaskManager:
    """Return the module-level ``task_manager`` singleton."""
    return task_manager
async def register_cancellable_task(
    task: asyncio.Task,
    task_type: str,
    user_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    task_id: Optional[str] = None,
) -> str:
    """
    Convenience function to register a task with the global task manager.

    Args:
        task: The asyncio task to track.
        task_type: Type of task (e.g. 'persona_generation').
        user_id: ID of the user who initiated the task.
        metadata: Additional metadata about the task.
        task_id: Optional pre-allocated task ID. A new one is generated when
            omitted. This lets callers hand the ID to a client before registering.

    Returns:
        The task ID for tracking
    """
    return await get_task_manager().register_task(
        task, task_type, user_id=user_id, metadata=metadata, task_id=task_id
    )
async def cancel_task_by_id(task_id: str) -> bool:
    """
    Cancel the task registered under ``task_id`` on the global task manager.

    Returns:
        True if the task was found and cancelled, False otherwise.
    """
    manager = get_task_manager()
    was_cancelled = await manager.cancel_task(task_id)
    return was_cancelled
async def store_task_result(task_id: str, status: str, result: Dict[str, Any] = None, error: str = None):
    """
    Record a task's status, result and error on the global task manager.

    Delegates to ``TaskManager.store_result``, which keeps the data available
    for polling.
    """
    manager = get_task_manager()
    await manager.store_result(task_id, status, result=result, error=error)
class CancellableTask:
    """
    Async context manager that registers the *current* asyncio task with the
    global task manager so it can be cancelled by ID.

    Usage:
        async with CancellableTask("persona_generation", user_id="123") as task_id:
            # Your long-running operation here
            await some_async_operation()

    ``task_id`` is None when entered outside of an asyncio task.

    On exit, if nothing inside the block recorded a final status for
    ``task_id``, the exit outcome is stored:
    - 'completed' when the block exits normally;
    - 'cancelled' when it exits with CancelledError;
    - 'failed' with the exception message otherwise.

    Recording a status stamps ``completed_at``, so the entry expires after the
    result TTL. It also stops the registration from being treated as live
    work: the registered task is the enclosing task, which keeps running after
    the block ends.
    """

    def __init__(self, task_type: str, user_id: Optional[str] = None, metadata: Optional[Dict[str, Any]] = None):
        self.task_type = task_type
        self.user_id = user_id
        self.metadata = metadata
        self.task_id: Optional[str] = None

    async def __aenter__(self):
        current_task = asyncio.current_task()
        if current_task:
            self.task_id = await register_cancellable_task(
                current_task, self.task_type, self.user_id, self.metadata
            )
        return self.task_id

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Returning None (falsy) lets any exception from the block propagate.
        if self.task_id is None:
            return
        manager = get_task_manager()
        info = await manager.get_task_info(self.task_id)
        if info is None or info.status != "running":
            # Already finalised (result stored, cancelled, or swept): keep it.
            return
        if exc_type is None:
            await manager.store_result(self.task_id, "completed")
        elif issubclass(exc_type, asyncio.CancelledError):
            await manager.store_result(self.task_id, "cancelled")
        else:
            await manager.store_result(
                self.task_id, "failed", error=str(exc_val) or exc_type.__name__
            )
def check_cancellation():
    """
    Decorator factory that adds a cancellation checkpoint before an async function.
    Use it as ``@check_cancellation()`` on coroutines with long-running loops or
    operations.

    Before the wrapped coroutine runs, the checkpoint raises
    ``asyncio.CancelledError`` if the current task has an outstanding
    cancellation request.

    ``Task.cancelled()`` is only True for a task that has already *finished* as
    cancelled, so it is always False for the task executing this code. The
    useful signal is ``Task.cancelling()`` (Python 3.11+). It counts cancellation
    requests not withdrawn with ``uncancel()``, for example when an earlier
    ``CancelledError`` was caught and swallowed. On older Pythons only the
    ``cancelled()`` check is available.

    The wrapper preserves the wrapped function's name, docstring and other
    metadata via ``functools.wraps``.
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            current_task = asyncio.current_task()
            if current_task is not None:
                cancelling = getattr(current_task, "cancelling", None)  # 3.11+
                if current_task.cancelled() or (cancelling is not None and cancelling() > 0):
                    raise asyncio.CancelledError("Task was cancelled")
            return await func(*args, **kwargs)
        return wrapper
    return decorator