"""Celery worker tasks for async processing.""" import json import logging import os import sys from datetime import datetime, timedelta from pathlib import Path import redis import requests from celery import shared_task # Add parent directory to path for imports sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from app.config import settings from app.database import SessionLocal from app.models import Submission from app.prompts import PET_PROMPTS, DEFAULT_PROMPT logger = logging.getLogger(__name__) # Redis client for credits caching redis_client = redis.from_url(settings.REDIS_URL) def get_cached_credits() -> int: """Get cached credits from Redis, default to 0 if not found.""" credits = redis_client.get("sonauto_credits") return int(credits) if credits else 0 # ============================================================================= # PERIODIC TASKS (Celery Beat) # ============================================================================= @shared_task(bind=True, name="tasks.workers.process_pending_queue") def process_pending_queue(self): """ Process pending submissions and dispatch them to Sonauto API. Runs every 60 seconds via Celery Beat. Respects MAX_CONCURRENT_REQUESTS limit. """ # Check credits (BYPASSED FOR TESTING) # credits = get_cached_credits() # if credits == 0: # logger.warning("Credits exhausted - skipping queue processing") # return {"processed": 0, "reason": "no_credits"} db = SessionLocal() try: # Count active requests (sent but not received) active_count = ( db.query(Submission) .filter( Submission.sent_to_LLM.isnot(None), Submission.received_from_LLM.is_(None), ) .count() ) available_slots = settings.MAX_CONCURRENT_REQUESTS - active_count if available_slots <= 0: logger.info( f"No available slots ({active_count}/{settings.MAX_CONCURRENT_REQUESTS} active)" ) return {"processed": 0, "reason": "no_slots"} # Get pending submissions pending = ( db.query(Submission) .filter( Submission.entry_status == "pending", Submission.sent_to_LLM.is_(None), Submission.retry_count < settings.MAX_RETRIES, ) .order_by(Submission.created_at.asc()) .limit(available_slots) .all() ) processed = 0 for submission in pending: # Dispatch individual task for each submission send_to_sonauto.delay(submission.session_id) processed += 1 logger.info(f"Dispatched {processed} submissions to Sonauto") return {"processed": processed} finally: db.close() @shared_task(bind=True, name="tasks.workers.check_credits") def check_credits(self): """ Check Sonauto credits and cache in Redis. Runs every 10 minutes via Celery Beat. """ try: response = requests.get( f"{settings.SONAUTO_API_URL}/credits/balance", headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"}, timeout=30, ) response.raise_for_status() data = response.json() num_credits = data.get("num_credits", 0) # Cache in Redis (expires in 15 minutes as safety) redis_client.setex("sonauto_credits", 900, num_credits) if num_credits == 0: logger.critical("Credits exhausted - submissions halted") elif num_credits < settings.MIN_AVAILABLE_CREDITS: logger.warning(f"Credits below threshold: {num_credits}") else: logger.info(f"Credits available: {num_credits}") return {"credits": num_credits} except requests.RequestException as e: logger.error(f"Failed to check credits: {e}") return {"error": str(e)} @shared_task(bind=True, name="tasks.workers.check_timeouts") def check_timeouts(self): """ Mark submissions as failed if webhook times out. Runs every 5 minutes via Celery Beat. """ db = SessionLocal() try: timeout_threshold = datetime.utcnow() - timedelta( minutes=settings.WEBHOOK_TIMEOUT_MINUTES ) stale_submissions = ( db.query(Submission) .filter( Submission.sent_to_LLM.isnot(None), Submission.received_from_LLM.is_(None), Submission.sent_to_LLM < timeout_threshold, ) .all() ) count = 0 for submission in stale_submissions: submission.entry_status = "fail" submission.LLM_status = "timeout" count += 1 logger.warning(f"Submission {submission.session_id} timed out") db.commit() logger.info(f"Marked {count} submissions as timed out") return {"timed_out": count} finally: db.close() @shared_task(bind=True, name="tasks.workers.cleanup_old_files") def cleanup_old_files(self): """ Delete files older than FILE_RETENTION_DAYS. Runs daily at 3 AM via Celery Beat. """ db = SessionLocal() try: cutoff_date = datetime.utcnow() - timedelta(days=settings.FILE_RETENTION_DAYS) old_submissions = ( db.query(Submission).filter(Submission.created_at < cutoff_date).all() ) deleted_count = 0 for submission in old_submissions: # Delete associated files for path_attr in ["photo_path", "generated_song_path", "generated_video_path"]: file_path = getattr(submission, path_attr) if file_path: path = Path(file_path) if path.exists(): path.unlink() deleted_count += 1 db.commit() logger.info( f"Cleanup: deleted {deleted_count} files from {len(old_submissions)} old submissions" ) return {"files_deleted": deleted_count, "submissions_processed": len(old_submissions)} finally: db.close() # ============================================================================= # SUBMISSION PROCESSING TASKS # ============================================================================= @shared_task( bind=True, name="tasks.workers.send_to_sonauto", autoretry_for=(requests.RequestException,), retry_backoff=True, retry_backoff_max=60, retry_kwargs={"max_retries": 3}, ) def send_to_sonauto(self, session_id: str): """ Send a single submission to Sonauto API. Called by process_pending_queue for each pending submission. """ db = SessionLocal() try: submission = db.get(Submission, session_id) if not submission: logger.error(f"Submission {session_id} not found") return {"error": "not_found"} # Get pet-specific prompt template or fall back to default prompt_template = PET_PROMPTS.get(submission.pet_type, DEFAULT_PROMPT) # Substitute variables prompt = prompt_template.format( owner_name=submission.owner_name, pet_name=submission.pet_name, music_vibe=submission.music_vibe, ) payload = { "prompt": prompt, "instrumental": False, "output_format": "mp3", "webhook_url": f"{settings.WEBHOOK_BASE_URL}/api/webhook", "enable_streaming": True, "stream_format": "mp3", } try: response = requests.post( f"{settings.SONAUTO_API_URL}/generations/v3", json=payload, headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"}, timeout=settings.API_REQUEST_TIMEOUT_SECONDS, ) response.raise_for_status() data = response.json() submission.LLM_task_id = data["task_id"] submission.sent_to_LLM = datetime.utcnow() submission.entry_status = "processing" db.commit() logger.info(f"Sent {session_id} to Sonauto, task_id: {data['task_id']}") return {"task_id": data["task_id"]} except requests.RequestException as e: # Log the error details for debugging if hasattr(e, 'response') and e.response is not None: logger.error(f"Sonauto API error for {session_id}: {e.response.status_code} - {e.response.text}") logger.error(f"Payload sent: {payload}") submission.retry_count += 1 if submission.retry_count >= settings.MAX_RETRIES: submission.entry_status = "fail" logger.error( f"Submission {session_id} failed after {settings.MAX_RETRIES} retries" ) db.commit() raise # Trigger Celery retry with exponential backoff finally: db.close() # ============================================================================= # POST-WEBHOOK PROCESSING CHAIN # ============================================================================= @shared_task( bind=True, name="tasks.workers.fetch_generation_details", autoretry_for=(requests.RequestException,), retry_backoff=True, retry_backoff_max=60, retry_kwargs={"max_retries": 3}, ) def fetch_generation_details(self, session_id: str) -> dict: """ Fetch full generation details from Sonauto API. First task in the post-webhook processing chain. """ db = SessionLocal() try: submission = db.get(Submission, session_id) if not submission: logger.error(f"Submission {session_id} not found") raise ValueError(f"Submission {session_id} not found") response = requests.get( f"{settings.SONAUTO_API_URL}/generations/{submission.LLM_task_id}", headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"}, timeout=30, ) response.raise_for_status() data = response.json() # Store full response submission.LLM_full_response = json.dumps(data) # Check if generation was successful api_status = data.get("status", "").upper() if api_status != "SUCCESS": submission.LLM_status = "fail" submission.entry_status = "fail" db.commit() logger.error(f"Generation failed for {session_id}: status={api_status}") raise ValueError(f"Generation failed with status: {api_status}") # Extract song URL and lyrics song_paths = data.get("song_paths", []) if not song_paths: submission.LLM_status = "fail" submission.entry_status = "fail" db.commit() logger.error(f"No song_paths in response for {session_id}") raise ValueError("No song_paths in Sonauto response") submission.LLM_response = song_paths[0] submission.lyrics = data.get("lyrics", "") submission.LLM_status = "success" db.commit() logger.info(f"Fetched generation details for {session_id}") # Return data needed by next task in chain return { "session_id": session_id, "song_url": data["song_paths"][0], "lyrics": data.get("lyrics", ""), } finally: db.close() @shared_task( bind=True, name="tasks.workers.download_audio", autoretry_for=(requests.RequestException,), retry_backoff=True, retry_backoff_max=60, retry_kwargs={"max_retries": 3}, ) def download_audio(self, result: dict) -> dict: """ Download MP3 from Sonauto CDN. Second task in the post-webhook processing chain. """ session_id = result["session_id"] song_url = result["song_url"] # Ensure storage directory exists settings.AUDIO_STORAGE.mkdir(parents=True, exist_ok=True) file_path = settings.AUDIO_STORAGE / f"{session_id}.mp3" response = requests.get(song_url, stream=True, timeout=120) response.raise_for_status() with open(file_path, "wb") as f: for chunk in response.iter_content(chunk_size=8192): f.write(chunk) # Update database db = SessionLocal() try: submission = db.get(Submission, session_id) submission.generated_song_path = str(file_path) db.commit() logger.info(f"Downloaded audio for {session_id}") finally: db.close() # Return data needed by next task in chain return { "session_id": session_id, "audio_path": str(file_path), "lyrics": result.get("lyrics", ""), } @shared_task( bind=True, name="tasks.workers.create_video", autoretry_for=(Exception,), retry_backoff=True, retry_backoff_max=60, retry_kwargs={"max_retries": 1}, # Only 1 retry for video generation ) def create_video(self, result: dict) -> dict: """ Create MP4 video combining pet photo with audio. Final task in the post-webhook processing chain. Uses the video_generator module. """ session_id = result["session_id"] audio_path = result["audio_path"] db = SessionLocal() try: submission = db.get(Submission, session_id) if not submission: raise ValueError(f"Submission {session_id} not found") # Mark video creation start submission.video_creation_start = datetime.utcnow() db.commit() try: # Ensure storage directory exists settings.VIDEO_STORAGE.mkdir(parents=True, exist_ok=True) output_path = settings.VIDEO_STORAGE / f"{session_id}.mp4" # Import and call the video generator from video_generator import create_video as video_gen video_gen( pet_img_path=submission.photo_path, audio_track_path=audio_path, output_path=str(output_path), ) # Verify video was created if not output_path.exists(): raise FileNotFoundError(f"Video file not created at {output_path}") # Mark success submission.generated_video_path = str(output_path) submission.video_creation_end = datetime.utcnow() submission.entry_status = "success" db.commit() logger.info(f"Video created for {session_id}: {output_path}") return {"session_id": session_id, "video_path": str(output_path)} except Exception as e: submission.entry_status = "fail" submission.video_creation_end = datetime.utcnow() db.commit() logger.error(f"Video creation failed for {session_id}: {e}") raise finally: db.close()