pahvalentines/backend/tasks/workers.py
michael 3e89e28716 feat(video): move video generation to Cloud Run
Offload CPU-intensive FFmpeg video creation to an autoscaling Cloud Run
service to clear the 1600+ task backlog. The Celery video worker becomes
a thin HTTP client (concurrency bumped from 2 to 50) that dispatches to
Cloud Run when CLOUD_RUN_VIDEO_URL is set, with local fallback for dev.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-10 09:58:57 -06:00

503 lines
17 KiB
Python

"""Celery worker tasks for async processing."""
import json
import logging
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
import requests
from celery import shared_task
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.config import settings
from app.database import SessionLocal
from app.models import Submission
from app.prompts import PET_PROMPTS, DEFAULT_PROMPT
# Module-level logger used by every task and helper below; named after this module.
logger = logging.getLogger(__name__)
# =============================================================================
# PERIODIC TASKS (Celery Beat)
# =============================================================================
@shared_task(bind=True, name="tasks.workers.check_timeouts")
def check_timeouts(self):
    """Safety net: mark submissions as failed when their webhook never arrived.

    Runs every 30 minutes via Celery Beat to catch edge cases where the
    per-submission ``check_submission_timeout`` task failed to run. Primary
    timeout handling is event-based via ``check_submission_timeout``.

    A submission is considered stale when it was sent to Sonauto more than
    ``settings.WEBHOOK_TIMEOUT_MINUTES`` ago, no webhook has been recorded,
    and it is still in the ``"processing"`` state. The state filter matches
    ``check_submission_timeout``. It keeps already-failed submissions from
    being re-marked (and re-counted) on every run.

    Returns:
        ``{"timed_out": <number of submissions marked as timed out this run>}``.
    """
    db = SessionLocal()
    try:
        timeout_threshold = datetime.utcnow() - timedelta(
            minutes=settings.WEBHOOK_TIMEOUT_MINUTES
        )
        stale_submissions = (
            db.query(Submission)
            .filter(
                Submission.sent_to_LLM.isnot(None),
                Submission.received_from_LLM.is_(None),
                Submission.sent_to_LLM < timeout_threshold,
                # Only act on submissions still waiting; anything else was
                # already resolved (success, fail, or a previous timeout).
                Submission.entry_status == "processing",
            )
            .all()
        )
        for submission in stale_submissions:
            submission.entry_status = "fail"
            submission.LLM_status = "timeout"
            logger.warning(f"Submission {submission.session_id} timed out")
        db.commit()
        count = len(stale_submissions)
        logger.info(f"Marked {count} submissions as timed out")
        return {"timed_out": count}
    finally:
        db.close()
@shared_task(bind=True, name="tasks.workers.cleanup_old_files")
def cleanup_old_files(self):
    """Delete GCS files belonging to submissions older than FILE_RETENTION_DAYS.

    Runs daily at 3 AM via Celery Beat. For every submission created before
    the cutoff it deletes the photo, song and video blobs referenced on the
    row. It also deletes the composite record image, whose path is derived
    from the session id.

    Each path column whose blob was successfully deleted is set to ``None``.
    The row then no longer references a file that is gone, and later runs
    skip it.

    Returns:
        ``{"files_deleted": int, "submissions_processed": int}``.
    """
    from app.services import storage

    db = SessionLocal()
    try:
        cutoff_date = datetime.utcnow() - timedelta(days=settings.FILE_RETENTION_DAYS)
        old_submissions = (
            db.query(Submission).filter(Submission.created_at < cutoff_date).all()
        )
        deleted_count = 0
        for submission in old_submissions:
            # Delete the blobs referenced by the row's path columns.
            for path_attr in ("photo_path", "generated_song_path", "generated_video_path"):
                blob_path = getattr(submission, path_attr)
                if blob_path and storage.delete_blob(blob_path):
                    deleted_count += 1
                    # NOTE(review): assumes these columns are nullable (photo_path
                    # is treated as optional elsewhere in this module) — confirm
                    # against the Submission model.
                    setattr(submission, path_attr, None)
            # The composite record image has no column; its path is derived.
            composite_blob_path = storage.get_composite_blob_path(submission.session_id)
            if storage.delete_blob(composite_blob_path):
                deleted_count += 1
        # Persist the cleared path columns.
        db.commit()
        logger.info(
            f"Cleanup: deleted {deleted_count} files from {len(old_submissions)} old submissions"
        )
        return {"files_deleted": deleted_count, "submissions_processed": len(old_submissions)}
    finally:
        db.close()
# =============================================================================
# SUBMISSION PROCESSING TASKS
# =============================================================================
@shared_task(
    bind=True,
    name="tasks.workers.send_to_sonauto",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def send_to_sonauto(self, session_id: str):
    """Submit one song-generation request to the Sonauto API.

    Invoked right after the form is submitted. On success the submission
    stores the Sonauto task id and send time and moves to ``"processing"``.
    A ``check_submission_timeout`` task is then scheduled to fire after
    ``settings.WEBHOOK_TIMEOUT_MINUTES``.

    On a ``requests.RequestException`` the submission's ``retry_count`` is
    incremented. Once it reaches ``settings.MAX_RETRIES`` the entry is also
    marked ``"fail"``. The exception is then re-raised so Celery's autoretry
    backs off and tries again.

    Returns:
        ``{"task_id": ...}`` on success, or ``{"error": "not_found"}`` when
        there is no submission for ``session_id``.
    """
    session = SessionLocal()
    try:
        entry = session.get(Submission, session_id)
        if not entry:
            logger.error(f"Submission {session_id} not found")
            return {"error": "not_found"}

        # Pet-specific template when one exists, otherwise the generic prompt.
        template = PET_PROMPTS.get(entry.pet_type, DEFAULT_PROMPT)
        payload = dict(
            prompt=template.format(
                owner_name=entry.owner_name,
                pet_name=entry.pet_name,
                music_vibe=entry.music_vibe,
            ),
            instrumental=False,
            output_format="mp3",
            webhook_url=f"{settings.WEBHOOK_BASE_URL}/api/webhook",
            enable_streaming=True,
            stream_format="mp3",
        )
        auth_headers = {"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"}

        try:
            response = requests.post(
                f"{settings.SONAUTO_API_URL}/generations/v3",
                json=payload,
                headers=auth_headers,
                timeout=settings.API_REQUEST_TIMEOUT_SECONDS,
            )
            response.raise_for_status()
            sonauto_task_id = response.json()["task_id"]
        except requests.RequestException as exc:
            # When the server actually answered, log what it said plus what we sent.
            if getattr(exc, "response", None) is not None:
                logger.error(f"Sonauto API error for {session_id}: {exc.response.status_code} - {exc.response.text}")
                logger.error(f"Payload sent: {payload}")
            entry.retry_count += 1
            if entry.retry_count >= settings.MAX_RETRIES:
                entry.entry_status = "fail"
                logger.error(
                    f"Submission {session_id} failed after {settings.MAX_RETRIES} retries"
                )
            session.commit()
            raise  # hand back to Celery's autoretry (exponential backoff)

        entry.LLM_task_id = sonauto_task_id
        entry.sent_to_LLM = datetime.utcnow()
        entry.entry_status = "processing"
        session.commit()
        logger.info(f"Sent {session_id} to Sonauto, task_id: {sonauto_task_id}")

        # Fire the per-submission timeout check exactly when the timeout expires.
        check_submission_timeout.apply_async(
            args=[session_id],
            countdown=settings.WEBHOOK_TIMEOUT_MINUTES * 60,
        )
        logger.debug(
            f"Scheduled timeout check for {session_id} in {settings.WEBHOOK_TIMEOUT_MINUTES} minutes"
        )
        return {"task_id": sonauto_task_id}
    finally:
        session.close()
@shared_task(bind=True, name="tasks.workers.check_submission_timeout")
def check_submission_timeout(self, session_id: str):
    """Time out one submission whose Sonauto webhook never arrived.

    Scheduled by ``send_to_sonauto`` to run once the webhook timeout has
    elapsed. Only a submission that has no recorded webhook and is still in
    ``"processing"`` is changed: it becomes ``entry_status="fail"`` with
    ``LLM_status="timeout"``.

    Returns:
        A dict describing the outcome: ``{"error": "not_found"}``, or a
        ``"status"`` of ``"webhook_received"``, ``"already_handled"`` or
        ``"timed_out"``.
    """
    session = SessionLocal()
    try:
        entry = session.get(Submission, session_id)
        if not entry:
            logger.warning(f"Timeout check: submission {session_id} not found")
            return {"error": "not_found"}

        if entry.received_from_LLM is not None:
            # The webhook made it in time; nothing to do.
            logger.debug(f"Timeout check: {session_id} already received webhook")
            return {"status": "webhook_received"}

        if entry.entry_status != "processing":
            # Some other path already resolved this submission.
            logger.debug(f"Timeout check: {session_id} already handled (status: {entry.entry_status})")
            return {"status": "already_handled"}

        entry.entry_status = "fail"
        entry.LLM_status = "timeout"
        session.commit()
        logger.warning(f"Submission {session_id} timed out (webhook never arrived)")
        return {"status": "timed_out"}
    finally:
        session.close()
# =============================================================================
# POST-WEBHOOK PROCESSING CHAIN
# =============================================================================
@shared_task(
    bind=True,
    name="tasks.workers.fetch_generation_details",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def fetch_generation_details(self, session_id: str) -> dict:
    """Fetch the full generation record from the Sonauto API.

    First task in the post-webhook processing chain. Stores the raw JSON
    response on the submission. On success it also records the first song
    URL, the lyrics and ``LLM_status="success"``.

    Args:
        session_id: Primary key of the ``Submission`` to update.

    Returns:
        ``{"session_id", "song_url", "lyrics"}``, the input for ``download_audio``.

    Raises:
        ValueError: If the submission does not exist, if the generation status
            is not ``SUCCESS``, or if the response has no ``song_paths``. In
            the last two cases the submission is marked ``"fail"`` first.
        requests.RequestException: On HTTP/network errors (retried by Celery).
    """
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error(f"Submission {session_id} not found")
            raise ValueError(f"Submission {session_id} not found")

        response = requests.get(
            f"{settings.SONAUTO_API_URL}/generations/{submission.LLM_task_id}",
            headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()

        # Store the full response.
        submission.LLM_full_response = json.dumps(data)

        # `or ""` also covers an explicit `"status": null`. `.get("status", "")`
        # would return None there, and `.upper()` would raise AttributeError
        # before the submission could be marked as failed.
        api_status = str(data.get("status") or "").upper()
        if api_status != "SUCCESS":
            submission.LLM_status = "fail"
            submission.entry_status = "fail"
            db.commit()
            logger.error(f"Generation failed for {session_id}: status={api_status}")
            raise ValueError(f"Generation failed with status: {api_status}")

        # A missing or null `song_paths` is treated like an empty list.
        song_paths = data.get("song_paths") or []
        if not song_paths:
            submission.LLM_status = "fail"
            submission.entry_status = "fail"
            db.commit()
            logger.error(f"No song_paths in response for {session_id}")
            raise ValueError("No song_paths in Sonauto response")

        song_url = song_paths[0]
        lyrics = data.get("lyrics", "")
        submission.LLM_response = song_url
        submission.lyrics = lyrics
        submission.LLM_status = "success"
        db.commit()
        logger.info(f"Fetched generation details for {session_id}")

        # Data needed by the next task in the chain.
        return {
            "session_id": session_id,
            "song_url": song_url,
            "lyrics": lyrics,
        }
    finally:
        db.close()
@shared_task(
    bind=True,
    name="tasks.workers.download_audio",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def download_audio(self, result: dict) -> dict:
    """Download the generated MP3 from Sonauto's CDN and upload it to GCS.

    Second task in the post-webhook processing chain. ``result`` is the dict
    returned by ``fetch_generation_details`` (``session_id``, ``song_url``,
    and optionally ``lyrics``).

    Returns:
        ``{"session_id", "audio_blob_path", "lyrics"}`` for ``create_video``.

    Raises:
        ValueError: If no submission exists for ``session_id``. This matches
            the other tasks instead of failing with ``AttributeError`` on ``None``.
        requests.RequestException: On download errors (retried by Celery).
    """
    from app.services import storage

    session_id = result["session_id"]
    song_url = result["song_url"]

    # The whole file is read into memory before upload.
    response = requests.get(song_url, timeout=120)
    response.raise_for_status()
    audio_data = response.content

    blob_path = storage.get_audio_blob_path(session_id)
    storage.upload_bytes(audio_data, blob_path, content_type="audio/mpeg")

    # Record where the audio now lives.
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error(f"Submission {session_id} not found")
            raise ValueError(f"Submission {session_id} not found")
        submission.generated_song_path = blob_path
        db.commit()
        logger.info(f"Uploaded audio for {session_id} to GCS: {blob_path}")
    finally:
        db.close()

    # Data needed by the next task in the chain.
    return {
        "session_id": session_id,
        "audio_blob_path": blob_path,
        "lyrics": result.get("lyrics", ""),
    }
# requests.Timeout is a subclass of requests.RequestException, so listing it
# separately in autoretry_for is redundant (but harmless).
@shared_task(
    bind=True,
    name="tasks.workers.create_video",
    autoretry_for=(requests.RequestException, requests.Timeout),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 2},
)
def create_video(self, result: dict) -> dict:
    """Render the MP4 that pairs the pet photo with the generated song.

    Final task in the post-webhook processing chain. When
    ``settings.CLOUD_RUN_VIDEO_URL`` is set, rendering is delegated to the
    Cloud Run service. Otherwise it runs locally with FFmpeg as a dev
    fallback.

    The audio location is read from ``result["audio_blob_path"]``, falling
    back to ``result["audio_path"]``. The submission's
    ``video_creation_start``/``video_creation_end`` timestamps and
    ``entry_status`` are updated around the render.

    Returns:
        ``{"session_id": ..., "video_blob_path": ...}``.

    Raises:
        ValueError: If no audio path is present or the submission is missing.
        Exception: Any render failure is re-raised after marking the
            submission ``"fail"``.
    """
    from app.services import storage

    session_id = result["session_id"]
    audio_blob_path = result.get("audio_blob_path") or result.get("audio_path")
    if not audio_blob_path:
        raise ValueError("No audio path found in result (checked audio_blob_path and audio_path)")

    session = SessionLocal()
    try:
        entry = session.get(Submission, session_id)
        if not entry:
            raise ValueError(f"Submission {session_id} not found")

        entry.video_creation_start = datetime.utcnow()
        session.commit()

        try:
            if settings.CLOUD_RUN_VIDEO_URL:
                # Remote rendering on the autoscaling Cloud Run service.
                video_blob_path = _create_video_cloud_run(
                    session_id, audio_blob_path, entry.photo_path
                )
            else:
                # In-process FFmpeg rendering (local/dev fallback).
                video_blob_path = _create_video_local(
                    session_id, audio_blob_path, entry.photo_path, storage
                )

            entry.generated_video_path = video_blob_path
            entry.video_creation_end = datetime.utcnow()
            entry.entry_status = "success"
            session.commit()
            logger.info(f"Video created for {session_id}: {video_blob_path}")
            return {"session_id": session_id, "video_blob_path": video_blob_path}
        except Exception as exc:
            # NOTE(review): this also runs for requests.RequestException, which
            # Celery then autoretries, so the entry may show "fail" until a
            # later retry flips it to "success". Confirm this is acceptable for
            # anything polling entry_status.
            entry.entry_status = "fail"
            entry.video_creation_end = datetime.utcnow()
            session.commit()
            logger.error(f"Video creation failed for {session_id}: {exc}")
            raise
    finally:
        session.close()
def _create_video_cloud_run(session_id: str, audio_blob_path: str, photo_blob_path: str | None) -> str:
    """Ask the Cloud Run video service to render the MP4 and return its blob path.

    POSTs the session id and the input blob paths as JSON to
    ``{settings.CLOUD_RUN_VIDEO_URL}/generate``, authenticated with the
    ``X-API-Key`` header. Returns the ``video_blob_path`` field of the JSON
    reply.

    Raises:
        requests.HTTPError: If the service answers with an error status.
        requests.Timeout: If no reply arrives within ``CLOUD_RUN_VIDEO_TIMEOUT``.
        KeyError: If the reply has no ``video_blob_path`` field.
    """
    request_body = dict(
        session_id=session_id,
        audio_blob_path=audio_blob_path,
        photo_blob_path=photo_blob_path,
    )
    auth = {"X-API-Key": settings.VIDEO_SERVICE_API_KEY}
    logger.info(f"Dispatching video generation to Cloud Run for {session_id}")

    reply = requests.post(
        f"{settings.CLOUD_RUN_VIDEO_URL}/generate",
        json=request_body,
        headers=auth,
        timeout=settings.CLOUD_RUN_VIDEO_TIMEOUT,
    )
    reply.raise_for_status()
    video_blob_path = reply.json()["video_blob_path"]
    logger.info(f"Cloud Run returned video for {session_id}: {video_blob_path}")
    return video_blob_path
def _create_video_local(session_id: str, audio_blob_path: str, photo_path: str | None, storage) -> str:
"""Generate video locally using FFmpeg (dev fallback)."""
is_local_path = audio_blob_path.startswith("/") or audio_blob_path.startswith("../")
# Get audio file
if is_local_path:
audio_context = None
temp_audio_path = audio_blob_path
else:
audio_context = storage.temp_download(audio_blob_path, suffix=".mp3")
temp_audio_path = audio_context.__enter__()
try:
# Get pet photo
pet_img_temp_path = None
pet_img_context = None
if photo_path:
photo_is_local = photo_path.startswith("/") or photo_path.startswith("../")
if photo_is_local:
pet_img_temp_path = photo_path
else:
pet_img_context = storage.temp_download(photo_path, suffix=".jpg")
pet_img_temp_path = pet_img_context.__enter__()
try:
with storage.temp_file_for_upload(suffix=".mp4") as temp_video_path:
composite_output_path = None
composite_context = None
if photo_path:
composite_context = storage.temp_file_for_upload(suffix=".png")
composite_output_path = composite_context.__enter__()
try:
from video_generator import create_video as video_gen
video_gen(
pet_img_path=pet_img_temp_path,
audio_track_path=temp_audio_path,
output_path=temp_video_path,
composite_output_path=composite_output_path,
)
if not Path(temp_video_path).exists():
raise FileNotFoundError(f"Video file not created at {temp_video_path}")
video_blob_path = storage.get_video_blob_path(session_id)
storage.upload_file(temp_video_path, video_blob_path, content_type="video/mp4")
if composite_output_path and Path(composite_output_path).exists():
composite_blob_path = storage.get_composite_blob_path(session_id)
storage.upload_file(composite_output_path, composite_blob_path, content_type="image/png")
finally:
if composite_context:
composite_context.__exit__(None, None, None)
finally:
if pet_img_context:
pet_img_context.__exit__(None, None, None)
finally:
if audio_context:
audio_context.__exit__(None, None, None)
return video_blob_path