semblance/backend/app/services/task_manager.py
2025-12-19 19:26:16 +00:00

228 lines
No EOL
7.3 KiB
Python
Executable file

"""
Task Manager Service for handling cancellable long-running operations.
This service provides a centralized way to track and cancel asyncio tasks
across all generation processes in the application.
"""
import asyncio
import functools
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional
# Module logger; the task registration/cancellation/cleanup messages in this file go through it.
logger = logging.getLogger(__name__)
class TaskInfo:
    """Information about a running task tracked by ``TaskManager``.

    All fields are plain public attributes so callers can read them directly.
    """

    def __init__(
        self,
        task_id: str,
        task: asyncio.Task,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            task_id: ID under which the task is registered.
            task: The asyncio task being tracked.
            task_type: Free-form category label (e.g. 'persona_generation').
            user_id: ID of the user who initiated the task, if any.
            metadata: Extra data about the task; a fresh empty dict when omitted.
        """
        self.task_id = task_id
        self.task = task
        self.task_type = task_type
        self.user_id = user_id
        self.metadata = metadata or {}
        # Naive UTC timestamp: the same value datetime.utcnow() returns, but without
        # that API's deprecation (Python 3.12+). Kept naive so existing comparisons
        # against other naive datetimes keep working.
        self.created_at = datetime.now(timezone.utc).replace(tzinfo=None)
        # Starts as "running"; TaskManager.cancel_task switches it to "cancelled".
        self.status = "running"

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(task_id={self.task_id!r}, "
            f"task_type={self.task_type!r}, user_id={self.user_id!r}, "
            f"status={self.status!r})"
        )
class TaskManager:
    """Singleton service for managing cancellable tasks.

    Tracked tasks live in ``self._tasks`` (task_id -> TaskInfo). The coroutine
    methods guard it with ``self._task_lock``; none of their locked sections
    contain an ``await``, so the synchronous done-callback cleanup
    (``_discard_task``) cannot interleave with them on the event-loop thread.
    """

    _instance = None
    # Not referenced by the methods below; kept only so any external reference
    # to ``TaskManager._lock`` keeps working.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Every TaskManager() call returns the same instance.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every TaskManager() call; initialize state only once.
        if not getattr(self, "_initialized", False):
            self._tasks: Dict[str, TaskInfo] = {}
            self._task_lock = asyncio.Lock()
            self._initialized = True

    def generate_task_id(self) -> str:
        """Generate a unique task ID (a random UUID4 string)."""
        return str(uuid.uuid4())

    async def register_task(
        self,
        task: asyncio.Task,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """
        Register a new task for tracking and potential cancellation.

        Args:
            task: The asyncio task to track
            task_type: Type of task (e.g., 'persona_generation', 'discussion_guide')
            user_id: ID of the user who initiated the task
            metadata: Additional metadata about the task
            task_id: Optional custom task ID (will generate if not provided).
                Re-using an ID replaces the previous entry.

        Returns:
            The task ID for tracking
        """
        if task_id is None:
            task_id = self.generate_task_id()
        async with self._task_lock:
            previous = self._tasks.get(task_id)
            if (
                previous is not None
                and previous.task is not task
                and not previous.task.done()
            ):
                logger.warning(
                    "Task ID %s re-registered while its previous task is still "
                    "running; the previous task can no longer be cancelled by ID",
                    task_id,
                )
            self._tasks[task_id] = TaskInfo(task_id, task, task_type, user_id, metadata)
        # Remove the entry as soon as the task finishes. Done callbacks run on the
        # event-loop thread, so cleanup is done synchronously here instead of via
        # an unreferenced asyncio.create_task(...) (asyncio keeps only weak
        # references to tasks, and such a task can be left pending at loop
        # shutdown). Passing the finished task makes the cleanup a no-op if
        # task_id has since been re-registered for a different task.
        task.add_done_callback(lambda done: self._discard_task(task_id, done))
        logger.info(
            "Registered task %s of type %s for user %s", task_id, task_type, user_id
        )
        return task_id

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task by its ID.

        Args:
            task_id: The ID of the task to cancel

        Returns:
            True if task was found and cancelled, False otherwise
        """
        async with self._task_lock:
            task_info = self._tasks.get(task_id)
            if task_info is None:
                logger.warning("Task %s not found for cancellation", task_id)
                return False
            if task_info.task.done():
                logger.info("Task %s already completed", task_id)
                return False
            # cancel() only requests cancellation; CancelledError is raised inside
            # the task at its next suspension point.
            task_info.task.cancel()
            task_info.status = "cancelled"
            logger.info("Cancelled task %s of type %s", task_id, task_info.task_type)
            return True

    async def get_task_info(self, task_id: str) -> Optional[TaskInfo]:
        """Get information about a task by its ID (None if not registered)."""
        async with self._task_lock:
            return self._tasks.get(task_id)

    async def get_user_tasks(self, user_id: str) -> Dict[str, TaskInfo]:
        """Get all registered, not-yet-finished tasks for a specific user."""
        async with self._task_lock:
            return {
                task_id: task_info
                for task_id, task_info in self._tasks.items()
                if task_info.user_id == user_id and not task_info.task.done()
            }

    def _discard_task(self, task_id: str, task: Optional[asyncio.Task] = None) -> None:
        """Drop ``task_id`` from the registry without taking the lock.

        When ``task`` is given, the entry is removed only if it still refers to
        that task, so a finished task cannot evict a newer task registered under
        the same ID.
        """
        task_info = self._tasks.get(task_id)
        if task_info is None:
            return
        if task is not None and task_info.task is not task:
            return
        logger.info("Cleaning up completed task %s", task_id)
        del self._tasks[task_id]

    async def _cleanup_completed_task(
        self, task_id: str, task: Optional[asyncio.Task] = None
    ):
        """Internal method to clean up completed tasks (lock-protected ``_discard_task``)."""
        async with self._task_lock:
            self._discard_task(task_id, task)

    async def get_active_task_count(self) -> int:
        """Get the number of registered tasks that have not finished yet."""
        async with self._task_lock:
            return sum(1 for t in self._tasks.values() if not t.task.done())

    async def cleanup_all_tasks(self):
        """Force cleanup of all tasks (useful for testing/shutdown).

        Requests cancellation of every unfinished task, then forgets all entries.
        """
        async with self._task_lock:
            for task_info in self._tasks.values():
                if not task_info.task.done():
                    task_info.task.cancel()
            self._tasks.clear()
            logger.info("All tasks cleaned up")
# Global instance, created once at import time. Because TaskManager.__new__ caches
# the first instance in TaskManager._instance, any later TaskManager() call returns
# this same object.
# NOTE(review): constructing it here also creates asyncio.Lock objects at import
# time; on Python < 3.10 those bind to the loop current at import — confirm the
# supported Python version.
task_manager = TaskManager()
def get_task_manager() -> TaskManager:
    """Return the module-level TaskManager singleton (``task_manager``)."""
    shared_manager = task_manager
    return shared_manager
async def register_cancellable_task(
    task: asyncio.Task,
    task_type: str,
    user_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    task_id: Optional[str] = None,
) -> str:
    """
    Convenience function to register a task with the global task manager.

    Args:
        task: The asyncio task to track.
        task_type: Category label stored on the task's TaskInfo.
        user_id: ID of the user who initiated the task, if any.
        metadata: Extra data stored on the task's TaskInfo.
        task_id: Optional caller-chosen ID, forwarded to ``TaskManager.register_task``;
            one is generated when omitted (the previous behavior).

    Returns:
        The task ID for tracking
    """
    return await get_task_manager().register_task(
        task, task_type, user_id, metadata, task_id=task_id
    )
async def cancel_task_by_id(task_id: str) -> bool:
    """
    Ask the global task manager to cancel the task registered as ``task_id``.

    Returns:
        True if task was found and cancelled, False otherwise
    """
    manager = get_task_manager()
    was_cancelled = await manager.cancel_task(task_id)
    return was_cancelled
class CancellableTask:
    """
    Async context manager that registers the *current* asyncio task as
    cancellable for the duration of the ``async with`` block.

    Usage:
        async with CancellableTask("persona_generation", user_id="123") as task_id:
            # Your long-running operation here
            await some_async_operation()

    ``task_id`` is None when entered outside of an asyncio task. On exit the
    registration is removed, so a later cancel request for ``task_id`` cannot
    cancel unrelated work that the same task performs after the block.
    """

    def __init__(
        self,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.task_type = task_type
        self.user_id = user_id
        self.metadata = metadata
        self.task_id: Optional[str] = None

    async def __aenter__(self) -> Optional[str]:
        # Register the task that is running this `async with` block.
        current_task = asyncio.current_task()
        if current_task:
            self.task_id = await register_cancellable_task(
                current_task, self.task_type, self.user_id, self.metadata
            )
        return self.task_id

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # The manager's own cleanup only fires when the whole task finishes, and
        # the enclosing task usually keeps running after this block. Unregister
        # explicitly so the ID stops referring to the task once the block ends.
        if self.task_id is not None:
            await get_task_manager()._cleanup_completed_task(self.task_id)
        # Returning None lets any exception raised in the block propagate.
        return None
def check_cancellation():
    """
    Decorator factory that adds a cancellation checkpoint before each call of
    an ``async def`` function.

    Usage::

        @check_cancellation()
        async def step(...):
            ...

    ``Task.cancelled()`` only becomes True once a task has *finished*, so it is
    always False for the task that is currently running and cannot detect a
    pending cancel request. Instead the wrapper yields to the event loop once
    (``await asyncio.sleep(0)``). If ``cancel()`` has been requested for the
    current task, asyncio raises ``CancelledError`` at that await, before the
    wrapped function starts.
    """

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            if asyncio.current_task() is not None:
                await asyncio.sleep(0)  # cancellation checkpoint
            return await func(*args, **kwargs)

        return wrapper

    return decorator