Handle both the old (audio_path) and new (audio_blob_path) result keys in the create_video task to support tasks that were queued before the GCS migration was deployed. Also correctly detect local vs. GCS paths for photo_path. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
473 lines
17 KiB
Python
473 lines
17 KiB
Python
"""Celery worker tasks for async processing."""
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import sys
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
from celery import shared_task
|
|
|
|
# Add parent directory to path for imports
|
|
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
|
|
|
from app.config import settings
|
|
from app.database import SessionLocal
|
|
from app.models import Submission
|
|
from app.prompts import PET_PROMPTS, DEFAULT_PROMPT
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
# =============================================================================
|
|
# PERIODIC TASKS (Celery Beat)
|
|
# =============================================================================
|
|
|
|
|
|
@shared_task(bind=True, name="tasks.workers.check_timeouts")
def check_timeouts(self):
    """
    Safety net: Mark submissions as failed if webhook times out.

    Runs every 30 minutes via Celery Beat as a safety net for edge cases
    where the per-submission check_submission_timeout task failed to run.
    Primary timeout handling is event-based via check_submission_timeout.

    A submission is considered stale when it was sent to Sonauto more than
    WEBHOOK_TIMEOUT_MINUTES ago, no webhook has been recorded for it, and it
    is still in the "processing" state. The state filter mirrors the guard in
    check_submission_timeout; without it, every submission that had already
    timed out or failed would be re-marked and re-logged on every run.

    Returns:
        dict: {"timed_out": <number of submissions marked on this run>}
    """
    db = SessionLocal()
    try:
        # sent_to_LLM is written with datetime.utcnow(), so compare in naive UTC.
        timeout_threshold = datetime.utcnow() - timedelta(
            minutes=settings.WEBHOOK_TIMEOUT_MINUTES
        )

        stale_submissions = (
            db.query(Submission)
            .filter(
                Submission.sent_to_LLM.isnot(None),
                Submission.received_from_LLM.is_(None),
                Submission.sent_to_LLM < timeout_threshold,
                # Only rows that are still waiting; already-resolved rows are skipped.
                Submission.entry_status == "processing",
            )
            .all()
        )

        for submission in stale_submissions:
            submission.entry_status = "fail"
            submission.LLM_status = "timeout"
            logger.warning(f"Submission {submission.session_id} timed out")

        db.commit()

        count = len(stale_submissions)
        logger.info(f"Marked {count} submissions as timed out")
        return {"timed_out": count}
    finally:
        db.close()
|
|
|
|
|
|
@shared_task(bind=True, name="tasks.workers.cleanup_old_files")
def cleanup_old_files(self):
    """
    Remove GCS objects belonging to submissions past the retention window.

    Scheduled daily at 3 AM through Celery Beat. For every submission created
    more than FILE_RETENTION_DAYS ago, asks storage.delete_blob to remove:
      * the blob paths stored on the row (photo, generated song, generated video),
        skipping empty ones, and
      * the composite record image, whose path is derived from the session id.
    The submission rows themselves are not deleted.

    Returns:
        dict with "files_deleted" (number of delete_blob calls that returned a
        truthy value) and "submissions_processed" (number of rows examined).
    """
    from app.services import storage

    blob_columns = ("photo_path", "generated_song_path", "generated_video_path")

    session = SessionLocal()
    try:
        retention_start = datetime.utcnow() - timedelta(days=settings.FILE_RETENTION_DAYS)
        expired = (
            session.query(Submission)
            .filter(Submission.created_at < retention_start)
            .all()
        )

        removed = 0
        for row in expired:
            # Blobs whose paths are stored directly on the row.
            for column in blob_columns:
                path = getattr(row, column)
                if path and storage.delete_blob(path):
                    removed += 1

            # The composite image path is not stored on the row; derive it.
            if storage.delete_blob(storage.get_composite_blob_path(row.session_id)):
                removed += 1

        # NOTE(review): nothing above modifies the rows, so this commit looks like
        # a no-op — confirm whether the path columns were meant to be cleared.
        session.commit()

        processed = len(expired)
        logger.info(
            f"Cleanup: deleted {removed} files from {processed} old submissions"
        )
        return {"files_deleted": removed, "submissions_processed": processed}
    finally:
        session.close()
|
|
|
|
|
|
# =============================================================================
|
|
# SUBMISSION PROCESSING TASKS
|
|
# =============================================================================
|
|
|
|
|
|
@shared_task(
    bind=True,
    name="tasks.workers.send_to_sonauto",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def send_to_sonauto(self, session_id: str):
    """
    Send a single submission to Sonauto API.

    Called immediately after form submission. On success, stores the Sonauto
    task id and send time, moves the submission to "processing", and schedules
    check_submission_timeout to fire once WEBHOOK_TIMEOUT_MINUTES have elapsed.

    On a request failure, the submission's retry_count is incremented and the
    exception is re-raised so Celery's autoretry (exponential backoff) can try
    again. The submission is marked "fail" as soon as EITHER the application
    budget (settings.MAX_RETRIES) OR Celery's own retry budget is exhausted.
    Previously only the former was checked, so a submission could be left
    un-failed after Celery's last attempt.

    Args:
        session_id: Primary key of the Submission to send.

    Returns:
        {"task_id": <Sonauto task id>} on success,
        {"error": "not_found"} if the submission does not exist.

    Raises:
        requests.RequestException: re-raised so Celery retries / records the failure.
    """
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error(f"Submission {session_id} not found")
            return {"error": "not_found"}

        # Get pet-specific prompt template or fall back to default
        prompt_template = PET_PROMPTS.get(submission.pet_type, DEFAULT_PROMPT)

        # Substitute variables
        prompt = prompt_template.format(
            owner_name=submission.owner_name,
            pet_name=submission.pet_name,
            music_vibe=submission.music_vibe,
        )

        payload = {
            "prompt": prompt,
            "instrumental": False,
            "output_format": "mp3",
            "webhook_url": f"{settings.WEBHOOK_BASE_URL}/api/webhook",
            "enable_streaming": True,
            "stream_format": "mp3",
        }

        try:
            response = requests.post(
                f"{settings.SONAUTO_API_URL}/generations/v3",
                json=payload,
                headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
                timeout=settings.API_REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            data = response.json()
        except requests.RequestException as e:
            if getattr(e, "response", None) is not None:
                logger.error(
                    f"Sonauto API error for {session_id}: "
                    f"{e.response.status_code} - {e.response.text}"
                )
                logger.error(f"Payload sent: {payload}")
            else:
                # Connection errors / timeouts carry no HTTP response.
                logger.error(f"Sonauto API request failed for {session_id}: {e}")

            # Tolerate a NULL retry_count on older rows.
            submission.retry_count = (submission.retry_count or 0) + 1

            # Celery's budget comes from retry_kwargs above; request.retries is 0
            # on the first run, so the last attempt is when it reaches the max.
            celery_max_retries = getattr(self, "retry_kwargs", {}).get(
                "max_retries", self.max_retries
            )
            celery_exhausted = (
                celery_max_retries is not None
                and self.request.retries >= celery_max_retries
            )

            if submission.retry_count >= settings.MAX_RETRIES or celery_exhausted:
                submission.entry_status = "fail"
                logger.error(
                    f"Submission {session_id} failed after {submission.retry_count} attempts"
                )
            db.commit()
            raise  # Trigger Celery retry with exponential backoff

        task_id = data["task_id"]
        submission.LLM_task_id = task_id
        submission.sent_to_LLM = datetime.utcnow()
        submission.entry_status = "processing"
        db.commit()

        logger.info(f"Sent {session_id} to Sonauto, task_id: {task_id}")

        # Schedule timeout check to fire exactly when timeout expires
        check_submission_timeout.apply_async(
            args=[session_id],
            countdown=settings.WEBHOOK_TIMEOUT_MINUTES * 60,
        )
        logger.debug(
            f"Scheduled timeout check for {session_id} in {settings.WEBHOOK_TIMEOUT_MINUTES} minutes"
        )

        return {"task_id": task_id}
    finally:
        db.close()
|
|
|
|
|
|
@shared_task(bind=True, name="tasks.workers.check_submission_timeout")
def check_submission_timeout(self, session_id: str):
    """
    Fail one submission whose Sonauto webhook has not arrived in time.

    Enqueued by send_to_sonauto with a countdown equal to the webhook timeout.
    Possible outcomes:
      * submission missing          -> {"error": "not_found"}
      * webhook already recorded    -> {"status": "webhook_received"}
      * not in "processing" state   -> left untouched, {"status": "already_handled"}
      * still "processing"          -> marked fail/timeout, {"status": "timed_out"}
    """
    session = SessionLocal()
    try:
        record = session.get(Submission, session_id)

        if not record:
            logger.warning(f"Timeout check: submission {session_id} not found")
            return {"error": "not_found"}

        webhook_arrived = record.received_from_LLM is not None
        if webhook_arrived:
            logger.debug(f"Timeout check: {session_id} already received webhook")
            return {"status": "webhook_received"}

        if record.entry_status != "processing":
            # Another code path (e.g. an earlier failure) already resolved it.
            logger.debug(f"Timeout check: {session_id} already handled (status: {record.entry_status})")
            return {"status": "already_handled"}

        record.entry_status = "fail"
        record.LLM_status = "timeout"
        session.commit()
        logger.warning(f"Submission {session_id} timed out (webhook never arrived)")
        return {"status": "timed_out"}
    finally:
        session.close()
|
|
|
|
|
|
# =============================================================================
|
|
# POST-WEBHOOK PROCESSING CHAIN
|
|
# =============================================================================
|
|
|
|
|
|
@shared_task(
    bind=True,
    name="tasks.workers.fetch_generation_details",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def fetch_generation_details(self, session_id: str) -> dict:
    """
    Fetch full generation details from Sonauto API.

    First task in the post-webhook processing chain. Stores the raw JSON
    response, the first song URL and the lyrics on the submission.

    Args:
        session_id: Primary key of the Submission whose LLM_task_id is queried.

    Returns:
        dict with "session_id", "song_url" and "lyrics" for download_audio.

    Raises:
        ValueError: the submission does not exist; the API status is not
            SUCCESS; or the response has no song_paths. The last two cases
            also mark the submission failed.
        requests.RequestException: HTTP or network error, retried by Celery.
            On the final attempt the submission is marked failed first, since
            nothing later in the chain would update it otherwise.
    """
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error(f"Submission {session_id} not found")
            raise ValueError(f"Submission {session_id} not found")

        try:
            response = requests.get(
                f"{settings.SONAUTO_API_URL}/generations/{submission.LLM_task_id}",
                headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
                timeout=30,
            )
            response.raise_for_status()
            data = response.json()
        except requests.RequestException as exc:
            # request.retries is 0 on the first run; the final attempt is when it
            # reaches the max_retries configured in retry_kwargs above.
            max_retries = getattr(self, "retry_kwargs", {}).get(
                "max_retries", self.max_retries
            )
            if max_retries is not None and self.request.retries >= max_retries:
                submission.entry_status = "fail"
                db.commit()
                logger.error(
                    f"Fetching generation details failed for {session_id} "
                    f"after final retry: {exc}"
                )
            raise

        # Store full response
        submission.LLM_full_response = json.dumps(data)

        # Check if generation was successful (tolerate a missing or null status)
        api_status = str(data.get("status") or "").upper()
        if api_status != "SUCCESS":
            submission.LLM_status = "fail"
            submission.entry_status = "fail"
            db.commit()
            logger.error(f"Generation failed for {session_id}: status={api_status}")
            raise ValueError(f"Generation failed with status: {api_status}")

        # Extract song URL and lyrics
        song_paths = data.get("song_paths", [])
        if not song_paths:
            submission.LLM_status = "fail"
            submission.entry_status = "fail"
            db.commit()
            logger.error(f"No song_paths in response for {session_id}")
            raise ValueError("No song_paths in Sonauto response")

        song_url = song_paths[0]
        lyrics = data.get("lyrics", "")

        submission.LLM_response = song_url
        submission.lyrics = lyrics
        submission.LLM_status = "success"
        db.commit()

        logger.info(f"Fetched generation details for {session_id}")

        # Return data needed by next task in chain
        return {
            "session_id": session_id,
            "song_url": song_url,
            "lyrics": lyrics,
        }
    finally:
        db.close()
|
|
|
|
|
|
@shared_task(
    bind=True,
    name="tasks.workers.download_audio",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def download_audio(self, result: dict) -> dict:
    """
    Download MP3 from Sonauto CDN and upload to GCS.

    Second task in the post-webhook processing chain. Downloads the song into
    memory, uploads it to the GCS path from storage.get_audio_blob_path, and
    records that path on the submission as generated_song_path.

    Args:
        result: Output of fetch_generation_details. It must contain
            "session_id" and "song_url" and may contain "lyrics".

    Returns:
        dict with "session_id", "audio_blob_path" and "lyrics" for create_video.

    Raises:
        requests.RequestException: download failed; retried by Celery.
        ValueError: the submission row no longer exists.
        Errors raised by the storage helpers propagate unchanged.

    Whenever an error will NOT be retried (a non-request error, or the final
    request attempt), the submission is marked "fail" before re-raising.
    Otherwise it would stay "processing" forever.
    """
    from app.services import storage

    session_id = result["session_id"]
    song_url = result["song_url"]

    try:
        # Download audio into memory
        response = requests.get(song_url, timeout=120)
        response.raise_for_status()
        audio_data = response.content

        # Upload to GCS
        blob_path = storage.get_audio_blob_path(session_id)
        storage.upload_bytes(audio_data, blob_path, content_type="audio/mpeg")
    except Exception as exc:
        # Mirror Celery's autoretry decision: only RequestException is retried,
        # and only while request.retries is below the configured max.
        max_retries = getattr(self, "retry_kwargs", {}).get(
            "max_retries", self.max_retries
        )
        will_retry = isinstance(exc, requests.RequestException) and (
            max_retries is None or self.request.retries < max_retries
        )
        if not will_retry:
            db = SessionLocal()
            try:
                submission = db.get(Submission, session_id)
                if submission:
                    submission.entry_status = "fail"
                    db.commit()
            finally:
                db.close()
            logger.error(f"Audio download/upload failed for {session_id}: {exc}")
        raise

    # Update database with blob path
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            raise ValueError(f"Submission {session_id} not found")
        submission.generated_song_path = blob_path
        db.commit()
        logger.info(f"Uploaded audio for {session_id} to GCS: {blob_path}")
    finally:
        db.close()

    # Return data needed by next task in chain
    return {
        "session_id": session_id,
        "audio_blob_path": blob_path,
        "lyrics": result.get("lyrics", ""),
    }
|
|
|
|
|
|
def _is_local_path(path: str) -> bool:
    """Return True when *path* looks like a pre-GCS local filesystem path.

    Work queued before the GCS migration refers to files by absolute ("/...")
    or parent-relative ("../...") local paths. Anything else is treated as a
    GCS blob path.
    """
    return path.startswith(("/", "../"))


@shared_task(
    bind=True,
    name="tasks.workers.create_video",
    autoretry_for=(Exception,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 1},  # Only 1 retry for video generation
)
def create_video(self, result: dict) -> dict:
    """
    Create MP4 video combining pet photo with audio.

    Final task in the post-webhook processing chain. Downloads the audio (and
    the pet photo, if any) from GCS to temp files, or uses them in place when
    they are legacy local paths. It then renders the video (plus a composite
    image when a photo exists), uploads the results to GCS, and marks the
    submission "success".

    Args:
        result: Output of download_audio. It contains "session_id" and either
            "audio_blob_path" (current) or "audio_path" (tasks queued before
            the GCS migration).

    Returns:
        {"session_id": ..., "video_blob_path": ...}

    Raises:
        ValueError: no audio path in *result*, or the submission does not exist.
        Exception: any rendering/storage error is re-raised for Celery's
            autoretry. The submission is marked "fail" only on the final
            attempt, so a pending retry is not reported as a failure.
    """
    from contextlib import ExitStack

    from app.services import storage

    session_id = result["session_id"]
    # Support both old (audio_path) and new (audio_blob_path) keys for backward compatibility
    audio_blob_path = result.get("audio_blob_path") or result.get("audio_path")
    if not audio_blob_path:
        raise ValueError("No audio path found in result (checked audio_blob_path and audio_path)")

    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            raise ValueError(f"Submission {session_id} not found")

        # Mark video creation start
        submission.video_creation_start = datetime.utcnow()
        db.commit()

        try:
            # ExitStack releases every temp resource in reverse order of
            # acquisition, passing the real exception info to each manager.
            with ExitStack() as stack:
                # Audio: legacy tasks point at a local file, new ones at a GCS blob.
                if _is_local_path(audio_blob_path):
                    audio_path = audio_blob_path
                else:
                    audio_path = stack.enter_context(
                        storage.temp_download(audio_blob_path, suffix=".mp3")
                    )

                # Optional pet photo, same local-vs-GCS handling.
                pet_img_path = None
                if submission.photo_path:
                    if _is_local_path(submission.photo_path):
                        pet_img_path = submission.photo_path
                    else:
                        pet_img_path = stack.enter_context(
                            storage.temp_download(submission.photo_path, suffix=".jpg")
                        )

                temp_video_path = stack.enter_context(
                    storage.temp_file_for_upload(suffix=".mp4")
                )

                # A composite record image is only produced when a photo exists.
                composite_output_path = None
                if submission.photo_path:
                    composite_output_path = stack.enter_context(
                        storage.temp_file_for_upload(suffix=".png")
                    )

                # Import and call the video generator
                from video_generator import create_video as video_gen

                video_gen(
                    pet_img_path=pet_img_path,
                    audio_track_path=audio_path,
                    output_path=temp_video_path,
                    composite_output_path=composite_output_path,
                )

                # Verify video was created
                if not Path(temp_video_path).exists():
                    raise FileNotFoundError(f"Video file not created at {temp_video_path}")

                # Upload video to GCS
                video_blob_path = storage.get_video_blob_path(session_id)
                storage.upload_file(temp_video_path, video_blob_path, content_type="video/mp4")

                # Upload composite if it was created
                if composite_output_path and Path(composite_output_path).exists():
                    composite_blob_path = storage.get_composite_blob_path(session_id)
                    storage.upload_file(
                        composite_output_path, composite_blob_path, content_type="image/png"
                    )

            # Mark success
            submission.generated_video_path = video_blob_path
            submission.video_creation_end = datetime.utcnow()
            submission.entry_status = "success"
            db.commit()

            logger.info(f"Video created for {session_id} and uploaded to GCS: {video_blob_path}")
            return {"session_id": session_id, "video_blob_path": video_blob_path}

        except Exception as e:
            # request.retries is 0 on the first run; Celery retries while it is
            # below the max_retries configured in retry_kwargs above.
            max_retries = getattr(self, "retry_kwargs", {}).get(
                "max_retries", self.max_retries
            )
            if max_retries is None or self.request.retries < max_retries:
                logger.warning(f"Video creation failed for {session_id}, will retry: {e}")
                raise

            # Discard any half-flushed state so the failure can be committed.
            db.rollback()
            submission.entry_status = "fail"
            submission.video_creation_end = datetime.utcnow()
            db.commit()
            logger.error(f"Video creation failed for {session_id}: {e}")
            raise
    finally:
        db.close()
|