pahvalentines/backend/tasks/workers.py
michael 318a1e1d8d celery: fix invalid Sonauto tags
Only use the music_vibe as the tag - "love song" and "valentine"
are not valid Sonauto API tags.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-31 08:22:59 -06:00

456 lines
14 KiB
Python

"""Celery worker tasks for async processing."""
import json
import logging
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
import redis
import requests
from celery import shared_task
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.config import settings
from app.database import SessionLocal
from app.models import Submission
logger = logging.getLogger(__name__)
# Redis client for credits caching
redis_client = redis.from_url(settings.REDIS_URL)
def get_cached_credits() -> int:
"""Get cached credits from Redis, default to 0 if not found."""
credits = redis_client.get("sonauto_credits")
return int(credits) if credits else 0
# =============================================================================
# PERIODIC TASKS (Celery Beat)
# =============================================================================
@shared_task(bind=True, name="tasks.workers.process_pending_queue")
def process_pending_queue(self):
"""
Process pending submissions and dispatch them to Sonauto API.
Runs every 60 seconds via Celery Beat.
Respects MAX_CONCURRENT_REQUESTS limit.
"""
# Check credits (BYPASSED FOR TESTING)
# credits = get_cached_credits()
# if credits == 0:
# logger.warning("Credits exhausted - skipping queue processing")
# return {"processed": 0, "reason": "no_credits"}
db = SessionLocal()
try:
# Count active requests (sent but not received)
active_count = (
db.query(Submission)
.filter(
Submission.sent_to_LLM.isnot(None),
Submission.received_from_LLM.is_(None),
)
.count()
)
available_slots = settings.MAX_CONCURRENT_REQUESTS - active_count
if available_slots <= 0:
logger.info(
f"No available slots ({active_count}/{settings.MAX_CONCURRENT_REQUESTS} active)"
)
return {"processed": 0, "reason": "no_slots"}
# Get pending submissions
pending = (
db.query(Submission)
.filter(
Submission.entry_status == "pending",
Submission.sent_to_LLM.is_(None),
Submission.retry_count < settings.MAX_RETRIES,
)
.order_by(Submission.created_at.asc())
.limit(available_slots)
.all()
)
processed = 0
for submission in pending:
# Dispatch individual task for each submission
send_to_sonauto.delay(submission.session_id)
processed += 1
logger.info(f"Dispatched {processed} submissions to Sonauto")
return {"processed": processed}
finally:
db.close()
@shared_task(bind=True, name="tasks.workers.check_credits")
def check_credits(self):
"""
Check Sonauto credits and cache in Redis.
Runs every 10 minutes via Celery Beat.
"""
try:
response = requests.get(
f"{settings.SONAUTO_API_URL}/credits/balance",
headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
timeout=30,
)
response.raise_for_status()
data = response.json()
num_credits = data.get("num_credits", 0)
# Cache in Redis (expires in 15 minutes as safety)
redis_client.setex("sonauto_credits", 900, num_credits)
if num_credits == 0:
logger.critical("Credits exhausted - submissions halted")
elif num_credits < settings.MIN_AVAILABLE_CREDITS:
logger.warning(f"Credits below threshold: {num_credits}")
else:
logger.info(f"Credits available: {num_credits}")
return {"credits": num_credits}
except requests.RequestException as e:
logger.error(f"Failed to check credits: {e}")
return {"error": str(e)}
@shared_task(bind=True, name="tasks.workers.check_timeouts")
def check_timeouts(self):
"""
Mark submissions as failed if webhook times out.
Runs every 5 minutes via Celery Beat.
"""
db = SessionLocal()
try:
timeout_threshold = datetime.utcnow() - timedelta(
minutes=settings.WEBHOOK_TIMEOUT_MINUTES
)
stale_submissions = (
db.query(Submission)
.filter(
Submission.sent_to_LLM.isnot(None),
Submission.received_from_LLM.is_(None),
Submission.sent_to_LLM < timeout_threshold,
)
.all()
)
count = 0
for submission in stale_submissions:
submission.entry_status = "fail"
submission.LLM_status = "timeout"
count += 1
logger.warning(f"Submission {submission.session_id} timed out")
db.commit()
logger.info(f"Marked {count} submissions as timed out")
return {"timed_out": count}
finally:
db.close()
@shared_task(bind=True, name="tasks.workers.cleanup_old_files")
def cleanup_old_files(self):
"""
Delete files older than FILE_RETENTION_DAYS.
Runs daily at 3 AM via Celery Beat.
"""
db = SessionLocal()
try:
cutoff_date = datetime.utcnow() - timedelta(days=settings.FILE_RETENTION_DAYS)
old_submissions = (
db.query(Submission).filter(Submission.created_at < cutoff_date).all()
)
deleted_count = 0
for submission in old_submissions:
# Delete associated files
for path_attr in ["photo_path", "generated_song_path", "generated_video_path"]:
file_path = getattr(submission, path_attr)
if file_path:
path = Path(file_path)
if path.exists():
path.unlink()
deleted_count += 1
db.commit()
logger.info(
f"Cleanup: deleted {deleted_count} files from {len(old_submissions)} old submissions"
)
return {"files_deleted": deleted_count, "submissions_processed": len(old_submissions)}
finally:
db.close()
# =============================================================================
# SUBMISSION PROCESSING TASKS
# =============================================================================
@shared_task(
bind=True,
name="tasks.workers.send_to_sonauto",
autoretry_for=(requests.RequestException,),
retry_backoff=True,
retry_backoff_max=60,
retry_kwargs={"max_retries": 3},
)
def send_to_sonauto(self, session_id: str):
"""
Send a single submission to Sonauto API.
Called by process_pending_queue for each pending submission.
"""
db = SessionLocal()
try:
submission = db.get(Submission, session_id)
if not submission:
logger.error(f"Submission {session_id} not found")
return {"error": "not_found"}
payload = {
"tags": [submission.music_vibe],
"prompt": (
f"Create a heartfelt Valentine's Day love song celebrating the special bond "
f"between {submission.owner_name} and their beloved {submission.pet_type} "
f"{submission.pet_name}. Make it warm, genuine, and mention both names in the lyrics."
),
"instrumental": False,
"output_format": "mp3",
"webhook_url": f"{settings.WEBHOOK_BASE_URL}/api/webhook",
}
try:
response = requests.post(
f"{settings.SONAUTO_API_URL}/generations/v3",
json=payload,
headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
timeout=settings.API_REQUEST_TIMEOUT_SECONDS,
)
response.raise_for_status()
data = response.json()
submission.LLM_task_id = data["task_id"]
submission.sent_to_LLM = datetime.utcnow()
submission.entry_status = "processing"
db.commit()
logger.info(f"Sent {session_id} to Sonauto, task_id: {data['task_id']}")
return {"task_id": data["task_id"]}
except requests.RequestException as e:
# Log the error details for debugging
if hasattr(e, 'response') and e.response is not None:
logger.error(f"Sonauto API error for {session_id}: {e.response.status_code} - {e.response.text}")
logger.error(f"Payload sent: {payload}")
submission.retry_count += 1
if submission.retry_count >= settings.MAX_RETRIES:
submission.entry_status = "fail"
logger.error(
f"Submission {session_id} failed after {settings.MAX_RETRIES} retries"
)
db.commit()
raise # Trigger Celery retry with exponential backoff
finally:
db.close()
# =============================================================================
# POST-WEBHOOK PROCESSING CHAIN
# =============================================================================
@shared_task(
bind=True,
name="tasks.workers.fetch_generation_details",
autoretry_for=(requests.RequestException,),
retry_backoff=True,
retry_backoff_max=60,
retry_kwargs={"max_retries": 3},
)
def fetch_generation_details(self, session_id: str) -> dict:
"""
Fetch full generation details from Sonauto API.
First task in the post-webhook processing chain.
"""
db = SessionLocal()
try:
submission = db.get(Submission, session_id)
if not submission:
logger.error(f"Submission {session_id} not found")
raise ValueError(f"Submission {session_id} not found")
response = requests.get(
f"{settings.SONAUTO_API_URL}/generations/{submission.LLM_task_id}",
headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
timeout=30,
)
response.raise_for_status()
data = response.json()
# Store full response
submission.LLM_full_response = json.dumps(data)
# Check if generation was successful
api_status = data.get("status", "").upper()
if api_status != "SUCCESS":
submission.LLM_status = "fail"
submission.entry_status = "fail"
db.commit()
logger.error(f"Generation failed for {session_id}: status={api_status}")
raise ValueError(f"Generation failed with status: {api_status}")
# Extract song URL and lyrics
song_paths = data.get("song_paths", [])
if not song_paths:
submission.LLM_status = "fail"
submission.entry_status = "fail"
db.commit()
logger.error(f"No song_paths in response for {session_id}")
raise ValueError("No song_paths in Sonauto response")
submission.LLM_response = song_paths[0]
submission.lyrics = data.get("lyrics", "")
submission.LLM_status = "success"
db.commit()
logger.info(f"Fetched generation details for {session_id}")
# Return data needed by next task in chain
return {
"session_id": session_id,
"song_url": data["song_paths"][0],
"lyrics": data.get("lyrics", ""),
}
finally:
db.close()
@shared_task(
bind=True,
name="tasks.workers.download_audio",
autoretry_for=(requests.RequestException,),
retry_backoff=True,
retry_backoff_max=60,
retry_kwargs={"max_retries": 3},
)
def download_audio(self, result: dict) -> dict:
"""
Download MP3 from Sonauto CDN.
Second task in the post-webhook processing chain.
"""
session_id = result["session_id"]
song_url = result["song_url"]
# Ensure storage directory exists
settings.AUDIO_STORAGE.mkdir(parents=True, exist_ok=True)
file_path = settings.AUDIO_STORAGE / f"{session_id}.mp3"
response = requests.get(song_url, stream=True, timeout=120)
response.raise_for_status()
with open(file_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
# Update database
db = SessionLocal()
try:
submission = db.get(Submission, session_id)
submission.generated_song_path = str(file_path)
db.commit()
logger.info(f"Downloaded audio for {session_id}")
finally:
db.close()
# Return data needed by next task in chain
return {
"session_id": session_id,
"audio_path": str(file_path),
"lyrics": result.get("lyrics", ""),
}
@shared_task(
bind=True,
name="tasks.workers.create_video",
autoretry_for=(Exception,),
retry_backoff=True,
retry_backoff_max=60,
retry_kwargs={"max_retries": 1}, # Only 1 retry for video generation
)
def create_video(self, result: dict) -> dict:
"""
Create MP4 video combining pet photo with audio.
Final task in the post-webhook processing chain.
Uses the video_generator module.
"""
session_id = result["session_id"]
audio_path = result["audio_path"]
db = SessionLocal()
try:
submission = db.get(Submission, session_id)
if not submission:
raise ValueError(f"Submission {session_id} not found")
# Mark video creation start
submission.video_creation_start = datetime.utcnow()
db.commit()
try:
# Ensure storage directory exists
settings.VIDEO_STORAGE.mkdir(parents=True, exist_ok=True)
output_path = settings.VIDEO_STORAGE / f"{session_id}.mp4"
# Import and call the video generator
from video_generator import create_video as video_gen
video_gen(
pet_img_path=submission.photo_path,
audio_track_path=audio_path,
output_path=str(output_path),
)
# Verify video was created
if not output_path.exists():
raise FileNotFoundError(f"Video file not created at {output_path}")
# Mark success
submission.generated_video_path = str(output_path)
submission.video_creation_end = datetime.utcnow()
submission.entry_status = "success"
db.commit()
logger.info(f"Video created for {session_id}: {output_path}")
return {"session_id": session_id, "video_path": str(output_path)}
except Exception as e:
submission.entry_status = "fail"
submission.video_creation_end = datetime.utcnow()
db.commit()
logger.error(f"Video creation failed for {session_id}: {e}")
raise
finally:
db.close()