"""
Task Manager Service for handling cancellable long-running operations.

This service provides a centralized way to track and cancel asyncio tasks
across all generation processes in the application.
"""

import asyncio
import functools
import logging
import uuid
from datetime import datetime, timezone
from typing import Any, Dict, Optional

logger = logging.getLogger(__name__)


class TaskInfo:
    """Information about a tracked task.

    Instances are mutated in place by ``TaskManager`` as the task moves from
    ``"running"`` to a terminal status.
    """

    def __init__(
        self,
        task_id: str,
        task: asyncio.Task,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.task_id = task_id
        self.task = task
        self.task_type = task_type
        self.user_id = user_id
        # A fresh dict per instance when no metadata is supplied.
        self.metadata = metadata or {}
        self.created_at = datetime.now(timezone.utc)
        # Starts as "running"; later set by cancel_task, store_result, or the
        # manager's done-callback.
        self.status = "running"
        self.result: Optional[Dict[str, Any]] = None
        self.error: Optional[str] = None
        # Set when the task reaches a terminal state; the TTL sweeper keys off it.
        self.completed_at: Optional[datetime] = None


class TaskManager:
    """Singleton service for managing cancellable tasks.

    Every ``TaskManager()`` call returns the same instance. Entries stay in
    the registry after their asyncio task finishes so polling endpoints can
    read the outcome; ``sweep_expired_tasks`` removes them once
    ``RESULT_TTL_SECONDS`` have elapsed since ``completed_at``.
    """

    RESULT_TTL_SECONDS = 300  # 5 minutes

    _instance = None
    # Not used by this module; kept for backward compatibility.
    _lock = asyncio.Lock()

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(TaskManager, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every TaskManager() call; set up state only once.
        if not getattr(self, "_initialized", False):
            self._tasks: Dict[str, TaskInfo] = {}
            self._task_lock = asyncio.Lock()
            self._initialized = True

    def generate_task_id(self) -> str:
        """Generate a unique task ID (a random UUID4 string)."""
        return str(uuid.uuid4())

    async def register_task(
        self,
        task: asyncio.Task,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
        task_id: Optional[str] = None,
    ) -> str:
        """
        Register a new task for tracking and potential cancellation.

        Args:
            task: The asyncio task to track
            task_type: Type of task (e.g., 'persona_generation', 'discussion_guide')
            user_id: ID of the user who initiated the task
            metadata: Additional metadata about the task
            task_id: Optional custom task ID (will generate if not provided).
                An existing entry with the same ID is replaced.

        Returns:
            The task ID for tracking
        """
        if task_id is None:
            task_id = self.generate_task_id()

        async with self._task_lock:
            self._tasks[task_id] = TaskInfo(task_id, task, task_type, user_id, metadata)

        # Record the final outcome when the task finishes. The callback runs
        # synchronously on the event loop, so no fire-and-forget task is
        # spawned (the loop only holds weak references to such tasks).
        task.add_done_callback(functools.partial(self._on_task_done, task_id))

        logger.info("Registered task %s of type %s for user %s", task_id, task_type, user_id)
        return task_id

    async def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a task by its ID.

        Args:
            task_id: The ID of the task to cancel

        Returns:
            True if the task was found and cancellation was requested,
            False if it is unknown or already done
        """
        async with self._task_lock:
            task_info = self._tasks.get(task_id)
            if not task_info:
                logger.warning("Task %s not found for cancellation", task_id)
                return False

            if task_info.task.done():
                logger.info("Task %s already completed", task_id)
                return False

            # cancel() only requests cancellation; the task sees CancelledError
            # at its next await.
            task_info.task.cancel()
            task_info.status = "cancelled"
            task_info.completed_at = datetime.now(timezone.utc)

            logger.info("Cancelled task %s of type %s", task_id, task_info.task_type)
            return True

    async def get_task_info(self, task_id: str) -> Optional[TaskInfo]:
        """Get information about a task by its ID, or None if unknown."""
        async with self._task_lock:
            return self._tasks.get(task_id)

    async def get_user_tasks(self, user_id: str) -> Dict[str, TaskInfo]:
        """Get all not-yet-done tasks for a specific user, keyed by task ID."""
        async with self._task_lock:
            return {
                task_id: task_info
                for task_id, task_info in self._tasks.items()
                if task_info.user_id == user_id and not task_info.task.done()
            }

    async def store_result(
        self,
        task_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
    ):
        """Store a task result for polling. Called by background functions.

        Unknown task IDs (e.g. already swept) are logged and ignored.
        """
        async with self._task_lock:
            task_info = self._tasks.get(task_id)
            if task_info is None:
                logger.warning("Cannot store result: task %s not found", task_id)
                return
            task_info.status = status  # 'completed', 'failed', 'cancelled'
            task_info.result = result
            task_info.error = error
            task_info.completed_at = datetime.now(timezone.utc)
            logger.info("Stored result for task %s: status=%s", task_id, status)

    async def get_task_status_dict(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Get task status and result for the polling endpoint, or None if unknown."""
        async with self._task_lock:
            info = self._tasks.get(task_id)
            if not info:
                return None
            return {
                'task_id': info.task_id,
                'status': info.status,
                'task_type': info.task_type,
                'result': info.result,
                'error': info.error,
                'created_at': info.created_at.isoformat(),
            }

    def _on_task_done(self, task_id: str, task: asyncio.Task) -> None:
        """Done-callback: finalize an entry whose task finished unrecorded.

        ``store_result`` and ``cancel_task`` set ``completed_at`` themselves.
        For any other finished task this records the terminal status so the
        TTL sweeper can eventually evict the entry (entries without
        ``completed_at`` are never swept).

        Runs synchronously on the event loop with no ``await`` between the
        lookup and the writes, so it does not take ``_task_lock``.
        """
        info = self._tasks.get(task_id)
        # Skip entries that were swept/cleared or re-registered with another task.
        if info is None or info.task is not task:
            return
        if info.completed_at is None:
            if task.cancelled():
                info.status = "cancelled"
            else:
                exc = task.exception()
                if exc is not None:
                    info.status = "failed"
                    info.error = str(exc) or type(exc).__name__
                    logger.error(
                        "Task %s of type %s failed", task_id, info.task_type, exc_info=exc
                    )
                else:
                    info.status = "completed"
            info.completed_at = datetime.now(timezone.utc)
        logger.info("Task %s asyncio Task finished — result retained for polling TTL", task_id)

    async def _cleanup_completed_task(self, task_id: str):
        """Log that a task finished; its result is retained until the TTL sweep.

        No longer scheduled by ``register_task`` (``_on_task_done`` handles
        completion); kept for any caller that still awaits it.
        """
        logger.info("Task %s asyncio Task finished — result retained for polling TTL", task_id)

    async def sweep_expired_tasks(self):
        """Remove tasks whose results have been retained past the TTL."""
        now = datetime.now(timezone.utc)
        async with self._task_lock:
            expired = [
                tid for tid, info in self._tasks.items()
                if info.completed_at
                and (now - info.completed_at).total_seconds() > self.RESULT_TTL_SECONDS
            ]
            for tid in expired:
                del self._tasks[tid]
        if expired:
            logger.info("Swept %d expired task results", len(expired))

    async def start_sweeper(self, interval_seconds: float = 60):
        """Loop forever, sweeping expired task results every ``interval_seconds``.

        Errors from a single sweep are logged and the loop continues; the
        loop ends only when its task is cancelled.
        """
        while True:
            await asyncio.sleep(interval_seconds)
            try:
                await self.sweep_expired_tasks()
            except Exception:
                # A dead sweeper would mean results are never evicted.
                logger.exception("Task result sweep failed")

    async def get_active_task_count(self) -> int:
        """Get the number of registered tasks that are not yet done."""
        async with self._task_lock:
            return sum(1 for info in self._tasks.values() if not info.task.done())

    async def cleanup_all_tasks(self):
        """Cancel every unfinished task and clear the registry (testing/shutdown)."""
        async with self._task_lock:
            for task_info in self._tasks.values():
                if not task_info.task.done():
                    task_info.task.cancel()
            self._tasks.clear()
            logger.info("All tasks cleaned up")


# Global instance
task_manager = TaskManager()


def get_task_manager() -> TaskManager:
    """Get the global task manager instance."""
    return task_manager


async def register_cancellable_task(
    task: asyncio.Task,
    task_type: str,
    user_id: Optional[str] = None,
    metadata: Optional[Dict[str, Any]] = None,
    task_id: Optional[str] = None,
) -> str:
    """
    Convenience function to register a task with the global task manager.

    Returns:
        The task ID for tracking
    """
    return await get_task_manager().register_task(task, task_type, user_id, metadata, task_id)


async def cancel_task_by_id(task_id: str) -> bool:
    """
    Convenience function to cancel a task by ID.

    Returns:
        True if task was found and cancelled, False otherwise
    """
    return await get_task_manager().cancel_task(task_id)


async def store_task_result(
    task_id: str,
    status: str,
    result: Optional[Dict[str, Any]] = None,
    error: Optional[str] = None,
):
    """Convenience function to store a task result."""
    await get_task_manager().store_result(task_id, status, result, error)


class CancellableTask:
    """
    Context manager that registers the current asyncio task as cancellable.

    Usage:
        async with CancellableTask("persona_generation", user_id="123") as task_id:
            # Your long-running operation here
            await some_async_operation()

    ``task_id`` is None when entered outside a running task.
    """

    def __init__(
        self,
        task_type: str,
        user_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ):
        self.task_type = task_type
        self.user_id = user_id
        self.metadata = metadata
        self.task_id = None

    async def __aenter__(self):
        current_task = asyncio.current_task()
        if current_task:
            self.task_id = await register_cancellable_task(
                current_task, self.task_type, self.user_id, self.metadata
            )
        return self.task_id

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # The manager's done-callback records the outcome; exceptions propagate.
        pass


def check_cancellation():
    """
    Decorator factory adding a cancellation checkpoint before each call.

    Should be used on coroutine functions that are called repeatedly from
    long-running loops or operations:

        @check_cancellation()
        async def step(...): ...
    """
    def decorator(func):
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            # A running task's cancelled() is always False (it only turns True
            # once the task is done), so it cannot serve as a checkpoint.
            # Yielding to the event loop delivers any pending cancellation
            # request as CancelledError before func starts.
            await asyncio.sleep(0)
            return await func(*args, **kwargs)
        return wrapper
    return decorator