pahvalentines/backend/tasks/workers.py
michael 2d9c6ca7a7 refactor: convert polling tasks to event-based architecture
- Remove process_pending_queue and check_credits scheduled tasks
- Add check_submission_timeout task for per-submission timeout handling
- Modify send_to_sonauto to schedule timeout check after successful API call
- Reduce check_timeouts frequency to 30min (safety net only)
- Update submissions endpoint to use apply_async with better error logging
- Remove unused config settings (QUEUE_PROCESSOR_INTERVAL, CREDITS_CHECK_INTERVAL)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 00:28:52 -06:00

420 lines
14 KiB
Python

"""Celery worker tasks for async processing."""
import json
import logging
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
import requests
from celery import shared_task
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.config import settings
from app.database import SessionLocal
from app.models import Submission
from app.prompts import PET_PROMPTS, DEFAULT_PROMPT
logger = logging.getLogger(__name__)
# =============================================================================
# PERIODIC TASKS (Celery Beat)
# =============================================================================
@shared_task(bind=True, name="tasks.workers.check_timeouts")
def check_timeouts(self):
    """
    Safety net: mark submissions as failed if the webhook never arrived.

    Intended to run periodically via Celery Beat (schedule configured
    elsewhere) to cover cases where the per-submission
    check_submission_timeout task did not run. Primary timeout handling is
    event-based via check_submission_timeout.

    Only submissions still in the "processing" state are considered, the
    same condition check_submission_timeout uses. Without that restriction
    every run would re-mark, re-log and re-count rows that were already
    timed out by an earlier run.

    Returns:
        dict: {"timed_out": <number of submissions marked as failed>}.
    """
    db = SessionLocal()
    try:
        # Anything sent before this instant has exceeded the webhook window.
        timeout_threshold = datetime.utcnow() - timedelta(
            minutes=settings.WEBHOOK_TIMEOUT_MINUTES
        )
        stale_submissions = (
            db.query(Submission)
            .filter(
                Submission.entry_status == "processing",
                Submission.sent_to_LLM.isnot(None),
                Submission.received_from_LLM.is_(None),
                Submission.sent_to_LLM < timeout_threshold,
            )
            .all()
        )

        for submission in stale_submissions:
            submission.entry_status = "fail"
            submission.LLM_status = "timeout"
            logger.warning("Submission %s timed out", submission.session_id)

        # Single commit for the whole batch.
        db.commit()

        count = len(stale_submissions)
        logger.info("Marked %s submissions as timed out", count)
        return {"timed_out": count}
    finally:
        db.close()
@shared_task(bind=True, name="tasks.workers.cleanup_old_files")
def cleanup_old_files(self):
    """
    Delete files belonging to submissions older than FILE_RETENTION_DAYS.

    For every submission created before the cutoff, removes the photo,
    generated song, generated video and the composite image
    (COMPOSITE_STORAGE/<session_id>.png) if they are present on disk.
    Intended to run on a schedule via Celery Beat (configured elsewhere).

    Each file is deleted independently: a file that is already gone is
    skipped silently, and any other OS-level failure is logged and skipped
    so one undeletable path cannot abort the whole cleanup run.

    Returns:
        dict: {"files_deleted": int, "submissions_processed": int}.
    """
    db = SessionLocal()
    try:
        cutoff_date = datetime.utcnow() - timedelta(days=settings.FILE_RETENTION_DAYS)
        old_submissions = (
            db.query(Submission).filter(Submission.created_at < cutoff_date).all()
        )

        deleted_count = 0
        for submission in old_submissions:
            # Stored paths first, then the composite image, whose location is
            # derived from the session id rather than stored on the row.
            candidates = [
                getattr(submission, path_attr)
                for path_attr in ("photo_path", "generated_song_path", "generated_video_path")
            ]
            candidates.append(settings.COMPOSITE_STORAGE / f"{submission.session_id}.png")

            for candidate in candidates:
                if not candidate:
                    continue
                try:
                    # Attempt the delete directly instead of exists()+unlink(),
                    # which can race with another process removing the file.
                    Path(candidate).unlink()
                except FileNotFoundError:
                    continue
                except OSError:
                    logger.exception("Cleanup: could not delete %s", candidate)
                    continue
                deleted_count += 1

        db.commit()
        logger.info(
            "Cleanup: deleted %s files from %s old submissions",
            deleted_count,
            len(old_submissions),
        )
        return {"files_deleted": deleted_count, "submissions_processed": len(old_submissions)}
    finally:
        db.close()
# =============================================================================
# SUBMISSION PROCESSING TASKS
# =============================================================================
@shared_task(
    bind=True,
    name="tasks.workers.send_to_sonauto",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def send_to_sonauto(self, session_id: str):
    """
    Send a single submission to the Sonauto API.

    Builds the prompt from the pet-specific template (falling back to
    DEFAULT_PROMPT), posts the generation request, records the returned
    task id on the submission and schedules check_submission_timeout to run
    once WEBHOOK_TIMEOUT_MINUTES have elapsed.

    Args:
        session_id: Primary key of the Submission to send.

    Returns:
        dict: {"task_id": <Sonauto task id>} on success, or
        {"error": "not_found"} if the submission does not exist.

    Raises:
        requests.RequestException: Re-raised after the failure has been
            recorded on the submission, so the task's autoretry_for
            configuration can retry with backoff.
        KeyError: If the API response body has no "task_id" field.
    """
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error("Submission %s not found", session_id)
            return {"error": "not_found"}

        # Get pet-specific prompt template or fall back to default.
        prompt_template = PET_PROMPTS.get(submission.pet_type, DEFAULT_PROMPT)
        prompt = prompt_template.format(
            owner_name=submission.owner_name,
            pet_name=submission.pet_name,
            music_vibe=submission.music_vibe,
        )
        payload = {
            "prompt": prompt,
            "instrumental": False,
            "output_format": "mp3",
            "webhook_url": f"{settings.WEBHOOK_BASE_URL}/api/webhook",
            "enable_streaming": True,
            "stream_format": "mp3",
        }

        try:
            response = requests.post(
                f"{settings.SONAUTO_API_URL}/generations/v3",
                json=payload,
                headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
                timeout=settings.API_REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            task_id = response.json()["task_id"]

            submission.LLM_task_id = task_id
            submission.sent_to_LLM = datetime.utcnow()
            submission.entry_status = "processing"
            # Persist before scheduling so the timeout task sees the
            # "processing" state when it eventually runs.
            db.commit()
            logger.info("Sent %s to Sonauto, task_id: %s", session_id, task_id)

            # Schedule timeout check to fire when the webhook window expires.
            check_submission_timeout.apply_async(
                args=[session_id],
                countdown=settings.WEBHOOK_TIMEOUT_MINUTES * 60,
            )
            logger.debug(
                "Scheduled timeout check for %s in %s minutes",
                session_id,
                settings.WEBHOOK_TIMEOUT_MINUTES,
            )
            return {"task_id": task_id}
        except requests.RequestException as e:
            failed_response = getattr(e, "response", None)
            if failed_response is not None:
                logger.error(
                    "Sonauto API error for %s: %s - %s",
                    session_id,
                    failed_response.status_code,
                    failed_response.text,
                )
                logger.error("Payload sent: %s", payload)
            else:
                # Timeouts / connection errors carry no HTTP response; log
                # them too so the failure is visible in this task's logs.
                logger.error(
                    "Sonauto request for %s failed without a response: %s",
                    session_id,
                    e,
                )

            # Treat a NULL counter as zero so bookkeeping cannot raise and
            # mask the original request error.
            submission.retry_count = (submission.retry_count or 0) + 1
            if submission.retry_count >= settings.MAX_RETRIES:
                submission.entry_status = "fail"
                logger.error(
                    "Submission %s failed after %s retries",
                    session_id,
                    settings.MAX_RETRIES,
                )
            db.commit()
            raise  # Trigger Celery retry with exponential backoff
    finally:
        db.close()
@shared_task(bind=True, name="tasks.workers.check_submission_timeout")
def check_submission_timeout(self, session_id: str):
    """
    Decide whether one submission has timed out waiting for its webhook.

    Scheduled by send_to_sonauto with a countdown equal to the webhook
    timeout. Outcomes, reported under the "status" key of the result:

    * "webhook_received" - received_from_LLM is set, nothing to do.
    * "timed_out"        - still "processing"; marked as fail / timeout.
    * "already_handled"  - any other entry_status, left untouched.

    A missing submission yields {"error": "not_found"}.
    """
    session = SessionLocal()
    try:
        record = session.get(Submission, session_id)

        if not record:
            logger.warning(f"Timeout check: submission {session_id} not found")
            return {"error": "not_found"}

        # Webhook made it in time.
        if record.received_from_LLM is not None:
            logger.debug(f"Timeout check: {session_id} already received webhook")
            return {"status": "webhook_received"}

        # Something else already moved the submission out of "processing"
        # (e.g. it failed for another reason), so leave it alone.
        if record.entry_status != "processing":
            logger.debug(f"Timeout check: {session_id} already handled (status: {record.entry_status})")
            return {"status": "already_handled"}

        # Still waiting after the full timeout window: give up on it.
        record.entry_status = "fail"
        record.LLM_status = "timeout"
        session.commit()
        logger.warning(f"Submission {session_id} timed out (webhook never arrived)")
        return {"status": "timed_out"}
    finally:
        session.close()
# =============================================================================
# POST-WEBHOOK PROCESSING CHAIN
# =============================================================================
@shared_task(
    bind=True,
    name="tasks.workers.fetch_generation_details",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def fetch_generation_details(self, session_id: str) -> dict:
    """
    Fetch full generation details from the Sonauto API.

    First task in the post-webhook processing chain. Stores the raw API
    response on the submission, validates that the generation succeeded and
    produced at least one song path, then records the first song URL and
    the lyrics.

    Args:
        session_id: Primary key of the Submission whose generation to fetch.

    Returns:
        dict: {"session_id", "song_url", "lyrics"} for the next task in the
        chain.

    Raises:
        ValueError: If the submission does not exist, the API reports a
            status other than SUCCESS, or the response has no song paths.
            In the latter two cases the submission is marked as failed
            before raising.
        requests.RequestException: On HTTP/transport errors (retried via
            the task's autoretry_for configuration).
    """
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error("Submission %s not found", session_id)
            raise ValueError(f"Submission {session_id} not found")

        response = requests.get(
            f"{settings.SONAUTO_API_URL}/generations/{submission.LLM_task_id}",
            headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()

        # Store full response
        submission.LLM_full_response = json.dumps(data)

        # A present-but-null (or non-string) status must be treated as a
        # failed generation instead of crashing on .upper().
        api_status = str(data.get("status") or "").upper()
        if api_status != "SUCCESS":
            submission.LLM_status = "fail"
            submission.entry_status = "fail"
            db.commit()
            logger.error("Generation failed for %s: status=%s", session_id, api_status)
            raise ValueError(f"Generation failed with status: {api_status}")

        # Extract song URL and lyrics
        song_paths = data.get("song_paths", [])
        if not song_paths:
            submission.LLM_status = "fail"
            submission.entry_status = "fail"
            db.commit()
            logger.error("No song_paths in response for %s", session_id)
            raise ValueError("No song_paths in Sonauto response")

        song_url = song_paths[0]
        lyrics = data.get("lyrics", "")

        submission.LLM_response = song_url
        submission.lyrics = lyrics
        submission.LLM_status = "success"
        db.commit()
        logger.info("Fetched generation details for %s", session_id)

        # Return data needed by next task in chain
        return {
            "session_id": session_id,
            "song_url": song_url,
            "lyrics": lyrics,
        }
    finally:
        db.close()
@shared_task(
    bind=True,
    name="tasks.workers.download_audio",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def download_audio(self, result: dict) -> dict:
    """
    Download the generated MP3 and record its location on the submission.

    Second task in the post-webhook processing chain.

    Args:
        result: Output of the previous task; must contain "session_id" and
            "song_url", and may contain "lyrics".

    Returns:
        dict: {"session_id", "audio_path", "lyrics"} for the next task in
        the chain.

    Raises:
        ValueError: If the submission no longer exists.
        requests.RequestException: On HTTP/transport errors (retried via
            the task's autoretry_for configuration).
    """
    session_id = result["session_id"]
    song_url = result["song_url"]

    # Ensure storage directory exists
    settings.AUDIO_STORAGE.mkdir(parents=True, exist_ok=True)
    file_path = settings.AUDIO_STORAGE / f"{session_id}.mp3"

    # A streamed response holds its connection open until it is closed, so
    # use it as a context manager to release the connection on every path.
    with requests.get(song_url, stream=True, timeout=120) as response:
        response.raise_for_status()
        with open(file_path, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

    # Update database
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            # Fail with the same explicit error the sibling tasks use rather
            # than an AttributeError on None.
            logger.error("Submission %s not found", session_id)
            raise ValueError(f"Submission {session_id} not found")
        submission.generated_song_path = str(file_path)
        db.commit()
        logger.info("Downloaded audio for %s", session_id)
    finally:
        db.close()

    # Return data needed by next task in chain
    return {
        "session_id": session_id,
        "audio_path": str(file_path),
        "lyrics": result.get("lyrics", ""),
    }
@shared_task(
    bind=True,
    name="tasks.workers.create_video",
    autoretry_for=(Exception,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 1},  # Only 1 retry for video generation
)
def create_video(self, result: dict) -> dict:
    """
    Render the final MP4 from the pet photo and the downloaded audio.

    Last step of the post-webhook processing chain; delegates the actual
    rendering to the video_generator module. Timestamps the start and end
    of the attempt on the submission and sets entry_status to "success" or
    "fail" accordingly. Any failure is re-raised after being recorded.

    Expects ``result`` to carry "session_id" and "audio_path"; returns
    {"session_id", "video_path"}.
    """
    session_id = result["session_id"]
    audio_path = result["audio_path"]

    session = SessionLocal()
    try:
        entry = session.get(Submission, session_id)
        if not entry:
            raise ValueError(f"Submission {session_id} not found")

        # Record when rendering began before doing any work.
        entry.video_creation_start = datetime.utcnow()
        session.commit()

        try:
            # Make sure both output directories are there.
            for storage_dir in (settings.VIDEO_STORAGE, settings.COMPOSITE_STORAGE):
                storage_dir.mkdir(parents=True, exist_ok=True)

            video_file = settings.VIDEO_STORAGE / f"{session_id}.mp4"

            # A composite image is only produced when a photo was uploaded.
            composite_target = (
                str(settings.COMPOSITE_STORAGE / f"{session_id}.png")
                if entry.photo_path
                else None
            )

            # Imported here, inside the guarded section, so an import
            # failure is recorded as a failed video creation as well.
            from video_generator import create_video as video_gen

            video_gen(
                pet_img_path=entry.photo_path,
                audio_track_path=audio_path,
                output_path=str(video_file),
                composite_output_path=composite_target,
            )

            # The generator returning is not enough; the file must exist.
            if not video_file.exists():
                raise FileNotFoundError(f"Video file not created at {video_file}")

            entry.generated_video_path = str(video_file)
            entry.video_creation_end = datetime.utcnow()
            entry.entry_status = "success"
            session.commit()
            logger.info(f"Video created for {session_id}: {video_file}")
            return {"session_id": session_id, "video_path": str(video_file)}
        except Exception as e:
            # Persist the failure, then let the error propagate.
            entry.entry_status = "fail"
            entry.video_creation_end = datetime.utcnow()
            session.commit()
            logger.error(f"Video creation failed for {session_id}: {e}")
            raise
    finally:
        session.close()