Replace the generic song generation prompt with pet-specific prompt templates. Each pet type (Dog, Cat, Fish, Bird, Hamster, Gerbil, Guinea Pig, Rabbit, Bearded Dragon, Leopard Gecko, Corn Snake) now has its own tailored prompt with pet-specific behaviors and content restrictions. Also fix typo: "Beared Dragon" → "Bearded Dragon" in schemas.py and home.js. The prompts.py file includes backwards compatibility for the old spelling to handle any pending database records. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
464 lines
15 KiB
Python
464 lines
15 KiB
Python
"""Celery worker tasks for async processing."""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
|
|
import redis
|
|
import requests
|
|
from celery import shared_task
|
|
|
|
# Add parent directory to path for imports
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from app.config import settings
|
|
from app.database import SessionLocal
|
|
from app.models import Submission
|
|
from app.prompts import PET_PROMPTS, DEFAULT_PROMPT
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)

# Shared Redis client built from settings.REDIS_URL. In this module it caches the
# Sonauto credit balance under the "sonauto_credits" key: check_credits writes it,
# get_cached_credits reads it.
redis_client = redis.from_url(settings.REDIS_URL)
|
|
|
|
|
|
def get_cached_credits() -> int:
    """Return the Sonauto credit balance that check_credits cached in Redis.

    A missing key (or an empty stored value) is reported as 0 credits.
    """
    cached_balance = redis_client.get("sonauto_credits")
    if not cached_balance:
        return 0
    return int(cached_balance)
|
|
|
|
|
|
# =============================================================================
|
|
# PERIODIC TASKS (Celery Beat)
|
|
# =============================================================================
|
|
|
|
|
|
@shared_task(bind=True, name="tasks.workers.process_pending_queue")
def process_pending_queue(self):
    """
    Process pending submissions and dispatch them to Sonauto API.

    Runs every 60 seconds via Celery Beat.
    Respects MAX_CONCURRENT_REQUESTS limit. A submission occupies a slot from
    the moment ``send_to_sonauto`` records ``sent_to_LLM`` until either
    ``received_from_LLM`` is set or the submission is marked ``"fail"``
    (for example by ``check_timeouts``).

    Returns:
        dict with ``processed`` (number of submissions dispatched) and, when
        nothing could be dispatched for lack of capacity, a ``reason`` key.
    """
    # TODO: re-enable this credit check before production. While it is bypassed,
    # submissions are dispatched even when the cached balance is 0.
    # credits = get_cached_credits()
    # if credits == 0:
    #     logger.warning("Credits exhausted - skipping queue processing")
    #     return {"processed": 0, "reason": "no_credits"}

    db = SessionLocal()
    try:
        # In-flight requests: sent, no result received yet, and not already given
        # up on. check_timeouts marks stale rows "fail" without setting
        # received_from_LLM, so without the entry_status filter every timeout
        # would hold a slot forever and eventually block the queue entirely.
        # (send_to_sonauto sets entry_status="processing" together with
        # sent_to_LLM, so entry_status is non-NULL for these rows.)
        active_count = (
            db.query(Submission)
            .filter(
                Submission.sent_to_LLM.isnot(None),
                Submission.received_from_LLM.is_(None),
                Submission.entry_status != "fail",
            )
            .count()
        )

        available_slots = settings.MAX_CONCURRENT_REQUESTS - active_count
        if available_slots <= 0:
            logger.info(
                f"No available slots ({active_count}/{settings.MAX_CONCURRENT_REQUESTS} active)"
            )
            return {"processed": 0, "reason": "no_slots"}

        # Oldest-first pending submissions that were never sent and still have
        # retries left, capped at the number of free slots.
        pending = (
            db.query(Submission)
            .filter(
                Submission.entry_status == "pending",
                Submission.sent_to_LLM.is_(None),
                Submission.retry_count < settings.MAX_RETRIES,
            )
            .order_by(Submission.created_at.asc())
            .limit(available_slots)
            .all()
        )

        # Dispatch an individual send_to_sonauto task for each submission.
        for submission in pending:
            send_to_sonauto.delay(submission.session_id)

        processed = len(pending)
        logger.info(f"Dispatched {processed} submissions to Sonauto")
        return {"processed": processed}
    finally:
        db.close()
|
|
|
|
|
|
@shared_task(bind=True, name="tasks.workers.check_credits")
def check_credits(self):
    """
    Fetch the Sonauto credit balance and cache it in Redis.

    Runs every 10 minutes via Celery Beat. The balance is stored under the
    ``sonauto_credits`` key with a 900-second (15-minute) TTL as a safety net,
    so the cached value expires if this task stops refreshing it.

    Returns:
        ``{"credits": <balance>}`` on success, or ``{"error": <message>}`` when
        the HTTP request fails.
    """
    balance_url = f"{settings.SONAUTO_API_URL}/credits/balance"
    auth_headers = {"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"}

    try:
        resp = requests.get(balance_url, headers=auth_headers, timeout=30)
        resp.raise_for_status()
        balance = resp.json().get("num_credits", 0)

        redis_client.setex("sonauto_credits", 900, balance)

        # Severity escalates as the balance approaches zero.
        if balance == 0:
            logger.critical("Credits exhausted - submissions halted")
        elif balance < settings.MIN_AVAILABLE_CREDITS:
            logger.warning(f"Credits below threshold: {balance}")
        else:
            logger.info(f"Credits available: {balance}")

        return {"credits": balance}
    except requests.RequestException as exc:
        logger.error(f"Failed to check credits: {exc}")
        return {"error": str(exc)}
|
|
|
|
|
|
@shared_task(bind=True, name="tasks.workers.check_timeouts")
def check_timeouts(self):
    """
    Mark submissions as failed if webhook times out.

    Runs every 5 minutes via Celery Beat. A submission is stale when it was
    sent to Sonauto more than WEBHOOK_TIMEOUT_MINUTES ago, no result has been
    received, and it has not already been marked failed. Stale submissions get
    ``entry_status="fail"`` and ``LLM_status="timeout"``.

    Returns:
        ``{"timed_out": <number of submissions newly marked as timed out>}``
    """
    db = SessionLocal()
    try:
        timeout_threshold = datetime.utcnow() - timedelta(
            minutes=settings.WEBHOOK_TIMEOUT_MINUTES
        )

        stale_submissions = (
            db.query(Submission)
            .filter(
                Submission.sent_to_LLM.isnot(None),
                Submission.received_from_LLM.is_(None),
                Submission.sent_to_LLM < timeout_threshold,
                # Marking a timeout does not set received_from_LLM, so without
                # this filter every run would re-process (and re-log) all past
                # timeouts, growing without bound.
                Submission.entry_status != "fail",
            )
            .all()
        )

        for submission in stale_submissions:
            submission.entry_status = "fail"
            submission.LLM_status = "timeout"
            logger.warning(f"Submission {submission.session_id} timed out")

        db.commit()

        count = len(stale_submissions)
        logger.info(f"Marked {count} submissions as timed out")
        return {"timed_out": count}
    finally:
        db.close()
|
|
|
|
|
|
@shared_task(bind=True, name="tasks.workers.cleanup_old_files")
def cleanup_old_files(self):
    """
    Delete files older than FILE_RETENTION_DAYS.

    Runs daily at 3 AM via Celery Beat. For every submission created before the
    cutoff, deletes whichever of its photo, generated song, and generated video
    paths are set. A file that is already gone is ignored. A file that cannot
    be deleted is logged and skipped, so one bad file does not abort the whole
    cleanup run.

    Returns:
        dict with ``files_deleted`` and ``submissions_processed`` counts.
    """
    db = SessionLocal()
    try:
        cutoff_date = datetime.utcnow() - timedelta(days=settings.FILE_RETENTION_DAYS)

        old_submissions = (
            db.query(Submission).filter(Submission.created_at < cutoff_date).all()
        )

        # NOTE(review): the *_path columns are never cleared here, so the same
        # submissions are revisited on every run (their files are simply already
        # gone). Clearing them requires knowing the columns are nullable.
        deleted_count = 0
        for submission in old_submissions:
            for path_attr in ("photo_path", "generated_song_path", "generated_video_path"):
                file_path = getattr(submission, path_attr)
                if not file_path:
                    continue
                try:
                    Path(file_path).unlink()
                except FileNotFoundError:
                    # Already removed (earlier run or out of band): nothing to do.
                    continue
                except OSError as e:
                    logger.error(
                        f"Cleanup: could not delete {file_path} "
                        f"(submission {submission.session_id}): {e}"
                    )
                    continue
                deleted_count += 1

        logger.info(
            f"Cleanup: deleted {deleted_count} files from {len(old_submissions)} old submissions"
        )
        return {"files_deleted": deleted_count, "submissions_processed": len(old_submissions)}
    finally:
        db.close()
|
|
|
|
|
|
# =============================================================================
|
|
# SUBMISSION PROCESSING TASKS
|
|
# =============================================================================
|
|
|
|
|
|
@shared_task(
    bind=True,
    name="tasks.workers.send_to_sonauto",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def send_to_sonauto(self, session_id: str):
    """
    Send a single submission to Sonauto API.

    Called by process_pending_queue for each pending submission.

    The API is only called for a submission that is still ``"pending"`` and has
    not been sent yet. process_pending_queue re-dispatches unsent submissions on
    every tick, and Celery also auto-retries this task on
    ``requests.RequestException``, so the same submission can be queued more
    than once. Without this guard a duplicate would start a second generation.

    Returns:
        ``{"task_id": ...}`` on success, ``{"error": "not_found"}`` if the
        submission does not exist, or ``{"skipped": <reason>}`` when the
        submission was already sent or is no longer pending.

    Raises:
        requests.RequestException: on API failure (including a response without
            ``task_id``). ``retry_count`` is bumped first, and the submission is
            marked ``"fail"`` once MAX_RETRIES is reached. Celery then retries
            with exponential backoff.
    """
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error(f"Submission {session_id} not found")
            return {"error": "not_found"}

        # Idempotency guards against duplicate dispatches / stale retries.
        if submission.sent_to_LLM is not None:
            logger.info(f"Submission {session_id} already sent to Sonauto - skipping")
            return {"skipped": "already_sent"}
        if submission.entry_status != "pending":
            logger.info(
                f"Submission {session_id} is '{submission.entry_status}', not pending - skipping"
            )
            return {"skipped": "not_pending"}

        # Get pet-specific prompt template or fall back to default
        prompt_template = PET_PROMPTS.get(submission.pet_type, DEFAULT_PROMPT)

        # Substitute variables
        prompt = prompt_template.format(
            owner_name=submission.owner_name,
            pet_name=submission.pet_name,
            music_vibe=submission.music_vibe,
        )

        payload = {
            "prompt": prompt,
            "instrumental": False,
            "output_format": "mp3",
            "webhook_url": f"{settings.WEBHOOK_BASE_URL}/api/webhook",
            "enable_streaming": True,
            "stream_format": "mp3",
        }

        try:
            response = requests.post(
                f"{settings.SONAUTO_API_URL}/generations/v3",
                json=payload,
                headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
                timeout=settings.API_REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            data = response.json()

            task_id = data.get("task_id")
            if not task_id:
                # Treat as an API failure so it goes through the bounded retry
                # path below. A bare KeyError would be neither counted nor
                # retried, leaving the submission pending and re-dispatched
                # on every queue tick indefinitely.
                raise requests.RequestException(
                    f"Sonauto response missing task_id: {data}"
                )

            submission.LLM_task_id = task_id
            submission.sent_to_LLM = datetime.utcnow()
            submission.entry_status = "processing"
            db.commit()

            logger.info(f"Sent {session_id} to Sonauto, task_id: {task_id}")
            return {"task_id": task_id}

        except requests.RequestException as e:
            # Log the error details for debugging
            err_response = getattr(e, "response", None)
            if err_response is not None:
                logger.error(
                    f"Sonauto API error for {session_id}: "
                    f"{err_response.status_code} - {err_response.text}"
                )
            else:
                logger.error(f"Sonauto API request failed for {session_id}: {e}")
            logger.error(f"Payload sent: {payload}")

            submission.retry_count += 1
            if submission.retry_count >= settings.MAX_RETRIES:
                submission.entry_status = "fail"
                logger.error(
                    f"Submission {session_id} failed after {settings.MAX_RETRIES} retries"
                )
            db.commit()
            raise  # Trigger Celery retry with exponential backoff
    finally:
        db.close()
|
|
|
|
|
|
# =============================================================================
|
|
# POST-WEBHOOK PROCESSING CHAIN
|
|
# =============================================================================
|
|
|
|
|
|
def _mark_generation_failed(db, submission) -> None:
    """Record a failed Sonauto generation on ``submission`` and commit."""
    submission.LLM_status = "fail"
    submission.entry_status = "fail"
    db.commit()


@shared_task(
    bind=True,
    name="tasks.workers.fetch_generation_details",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def fetch_generation_details(self, session_id: str) -> dict:
    """
    Fetch full generation details from Sonauto API.

    First task in the post-webhook processing chain. Stores the raw JSON
    response on the submission. It then either records the first song URL and
    the lyrics (``LLM_status="success"``) or marks the submission failed.

    Returns:
        dict with ``session_id``, ``song_url`` and ``lyrics`` for the next task.

    Raises:
        ValueError: if the submission does not exist, the generation status is
            not SUCCESS, or the response contains no ``song_paths``.
        requests.RequestException: on HTTP errors (auto-retried by Celery).
    """
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error(f"Submission {session_id} not found")
            raise ValueError(f"Submission {session_id} not found")

        response = requests.get(
            f"{settings.SONAUTO_API_URL}/generations/{submission.LLM_task_id}",
            headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()

        # Store full response
        submission.LLM_full_response = json.dumps(data)

        # `or ""` also covers an explicit null status. With .get(..., "") a null
        # would reach .upper() and crash before the submission is marked failed.
        api_status = str(data.get("status") or "").upper()
        if api_status != "SUCCESS":
            _mark_generation_failed(db, submission)
            logger.error(f"Generation failed for {session_id}: status={api_status}")
            raise ValueError(f"Generation failed with status: {api_status}")

        song_paths = data.get("song_paths", [])
        if not song_paths:
            _mark_generation_failed(db, submission)
            logger.error(f"No song_paths in response for {session_id}")
            raise ValueError("No song_paths in Sonauto response")

        song_url = song_paths[0]
        lyrics = data.get("lyrics", "")

        submission.LLM_response = song_url
        submission.lyrics = lyrics
        submission.LLM_status = "success"
        db.commit()

        logger.info(f"Fetched generation details for {session_id}")

        # Return data needed by next task in chain
        return {
            "session_id": session_id,
            "song_url": song_url,
            "lyrics": lyrics,
        }
    finally:
        db.close()
|
|
|
|
|
|
@shared_task(
    bind=True,
    name="tasks.workers.download_audio",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def download_audio(self, result: dict) -> dict:
    """
    Download MP3 from Sonauto CDN.

    Second task in the post-webhook processing chain. Streams
    ``result["song_url"]`` to ``AUDIO_STORAGE/<session_id>.mp3`` and records
    that path on the submission.

    The body is written to a ``.part`` sibling file and moved into place only
    once fully received. A failed or interrupted download therefore never
    leaves a truncated MP3 at the final path.

    Args:
        result: output of ``fetch_generation_details`` (``session_id``,
            ``song_url``, optional ``lyrics``).

    Returns:
        dict with ``session_id``, ``audio_path`` and ``lyrics`` for the next task.

    Raises:
        ValueError: if the submission no longer exists.
        requests.RequestException: on download errors (auto-retried by Celery).
    """
    session_id = result["session_id"]
    song_url = result["song_url"]

    # Ensure storage directory exists
    settings.AUDIO_STORAGE.mkdir(parents=True, exist_ok=True)
    file_path = settings.AUDIO_STORAGE / f"{session_id}.mp3"
    tmp_path = file_path.with_name(file_path.name + ".part")

    try:
        # With stream=True the body is read lazily. Using the response as a
        # context manager releases the connection even if streaming fails
        # part-way through.
        with requests.get(song_url, stream=True, timeout=120) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: readers see either no file or the full one.
        tmp_path.replace(file_path)
    except Exception:
        tmp_path.unlink(missing_ok=True)
        raise

    # Update database
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            raise ValueError(f"Submission {session_id} not found")
        submission.generated_song_path = str(file_path)
        db.commit()
        logger.info(f"Downloaded audio for {session_id}")
    finally:
        db.close()

    # Return data needed by next task in chain
    return {
        "session_id": session_id,
        "audio_path": str(file_path),
        "lyrics": result.get("lyrics", ""),
    }
|
|
|
|
|
|
# Number of times Celery may retry create_video after a failure. Also used
# inside the task to recognise the final attempt.
_VIDEO_MAX_RETRIES = 1


@shared_task(
    bind=True,
    name="tasks.workers.create_video",
    autoretry_for=(Exception,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": _VIDEO_MAX_RETRIES},  # Only 1 retry for video generation
)
def create_video(self, result: dict) -> dict:
    """
    Create MP4 video combining pet photo with audio.

    Final task in the post-webhook processing chain.
    Uses the video_generator module.

    On success the submission gets ``generated_video_path``,
    ``video_creation_end`` and ``entry_status="success"``. On failure the
    exception is re-raised so Celery can retry. The submission is only marked
    ``"fail"`` on the final attempt, so a transient error that succeeds on
    retry never surfaces as a failed status in the meantime.

    Args:
        result: output of ``download_audio`` (``session_id``, ``audio_path``).

    Returns:
        dict with ``session_id`` and ``video_path``.

    Raises:
        ValueError: if the submission does not exist.
        FileNotFoundError: if the generator produced no output file.
        Exception: anything raised by the video generator.
    """
    session_id = result["session_id"]
    audio_path = result["audio_path"]

    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            raise ValueError(f"Submission {session_id} not found")

        # Mark video creation start
        submission.video_creation_start = datetime.utcnow()
        db.commit()

        try:
            # Ensure storage directory exists
            settings.VIDEO_STORAGE.mkdir(parents=True, exist_ok=True)
            output_path = settings.VIDEO_STORAGE / f"{session_id}.mp4"

            # Imported here rather than at module level; video_generator is
            # only needed when a video is actually rendered.
            from video_generator import create_video as video_gen

            video_gen(
                pet_img_path=submission.photo_path,
                audio_track_path=audio_path,
                output_path=str(output_path),
            )

            # Verify video was created
            if not output_path.exists():
                raise FileNotFoundError(f"Video file not created at {output_path}")

            # Mark success
            submission.generated_video_path = str(output_path)
            submission.video_creation_end = datetime.utcnow()
            submission.entry_status = "success"
            db.commit()

            logger.info(f"Video created for {session_id}: {output_path}")
            return {"session_id": session_id, "video_path": str(output_path)}

        except Exception as e:
            # Discard any half-applied changes (e.g. a failed success commit)
            # so the failure can still be recorded cleanly.
            db.rollback()
            if self.request.retries >= _VIDEO_MAX_RETRIES:
                submission.entry_status = "fail"
                submission.video_creation_end = datetime.utcnow()
                db.commit()
                logger.error(f"Video creation failed for {session_id}: {e}")
            else:
                logger.warning(
                    f"Video creation failed for {session_id} "
                    f"(attempt {self.request.retries + 1}), retrying: {e}"
                )
            raise
    finally:
        db.close()
|