pahvalentines/backend/tasks/workers.py
michael 1c996c5919 fix: add backward compatibility for in-flight tasks during GCS migration
Handle both old (audio_path) and new (audio_blob_path) keys in create_video
task to support tasks that were queued before the GCS migration deployed.
Also properly detect local vs GCS paths for photo_path.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-05 09:23:08 -06:00

473 lines
17 KiB
Python

"""Celery worker tasks for async processing."""
import json
import logging
import os
import sys
from datetime import datetime, timedelta
from pathlib import Path
import requests
from celery import shared_task
# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from app.config import settings
from app.database import SessionLocal
from app.models import Submission
from app.prompts import PET_PROMPTS, DEFAULT_PROMPT
logger = logging.getLogger(__name__)
# =============================================================================
# PERIODIC TASKS (Celery Beat)
# =============================================================================
@shared_task(bind=True, name="tasks.workers.check_timeouts")
def check_timeouts(self):
    """
    Safety net: mark submissions as failed when the webhook never arrived.

    Intended to run periodically via Celery Beat (the original docstring
    states every 30 minutes; the schedule is configured outside this block)
    to catch submissions whose per-submission ``check_submission_timeout``
    task failed to run. Primary timeout handling is event-based via
    ``check_submission_timeout``.

    Only submissions that are still ``"processing"`` are touched, mirroring
    the condition used by ``check_submission_timeout``. Without that filter
    every previously timed-out row would be re-selected, re-written and
    re-logged on each run, and the reported count would keep growing.

    Returns:
        dict: ``{"timed_out": <number of submissions marked as failed>}``.
    """
    db = SessionLocal()
    try:
        # Naive UTC, matching how send_to_sonauto stamps sent_to_LLM.
        timeout_threshold = datetime.utcnow() - timedelta(
            minutes=settings.WEBHOOK_TIMEOUT_MINUTES
        )
        stale_submissions = (
            db.query(Submission)
            .filter(
                Submission.sent_to_LLM.isnot(None),
                Submission.received_from_LLM.is_(None),
                Submission.sent_to_LLM < timeout_threshold,
                # Skip rows that were already failed/handled on an earlier run.
                Submission.entry_status == "processing",
            )
            .all()
        )
        for submission in stale_submissions:
            submission.entry_status = "fail"
            submission.LLM_status = "timeout"
            logger.warning(f"Submission {submission.session_id} timed out")
        # Single commit for the whole batch.
        db.commit()
        count = len(stale_submissions)
        logger.info(f"Marked {count} submissions as timed out")
        return {"timed_out": count}
    finally:
        db.close()
@shared_task(bind=True, name="tasks.workers.cleanup_old_files")
def cleanup_old_files(self):
    """
    Remove stored blobs belonging to submissions older than FILE_RETENTION_DAYS.

    Scheduled via Celery Beat (the original docstring states daily at 3 AM;
    the schedule itself is configured outside this block). For every expired
    submission the photo, generated song, generated video and composite
    record image are passed to ``storage.delete_blob``.

    Returns:
        dict: ``{"files_deleted": <count>, "submissions_processed": <count>}``
        where ``files_deleted`` counts the ``delete_blob`` calls that
        returned a truthy value.
    """
    from app.services import storage

    file_columns = ("photo_path", "generated_song_path", "generated_video_path")
    session = SessionLocal()
    try:
        retention = timedelta(days=settings.FILE_RETENTION_DAYS)
        expiry = datetime.utcnow() - retention
        expired = (
            session.query(Submission).filter(Submission.created_at < expiry).all()
        )

        removed = 0
        for record in expired:
            # Blobs referenced directly by the submission row.
            for column in file_columns:
                stored_path = getattr(record, column)
                if stored_path and storage.delete_blob(stored_path):
                    removed += 1
            # The composite image path is derived from the session id, not stored.
            composite_path = storage.get_composite_blob_path(record.session_id)
            if storage.delete_blob(composite_path):
                removed += 1

        # NOTE(review): no row attributes are modified above, so this commit
        # appears to be a no-op and the same rows are revisited on every run
        # - confirm whether the path columns should be cleared.
        session.commit()
        processed = len(expired)
        logger.info(
            f"Cleanup: deleted {removed} files from {processed} old submissions"
        )
        return {"files_deleted": removed, "submissions_processed": processed}
    finally:
        session.close()
# =============================================================================
# SUBMISSION PROCESSING TASKS
# =============================================================================
@shared_task(
    bind=True,
    name="tasks.workers.send_to_sonauto",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def send_to_sonauto(self, session_id: str):
    """
    Submit one submission's song-generation request to the Sonauto API.

    Builds the prompt from the pet-specific template (or the default one),
    POSTs it, records the returned task id on the submission and schedules
    ``check_submission_timeout`` to run once the webhook timeout elapses.

    Args:
        session_id: Primary key of the ``Submission`` to send.

    Returns:
        dict: ``{"task_id": ...}`` on success, or ``{"error": "not_found"}``
        when no submission exists for ``session_id``.

    Raises:
        requests.RequestException: Re-raised after bookkeeping so that the
            task's ``autoretry_for`` configuration triggers a retry.
    """
    session = SessionLocal()
    try:
        record = session.get(Submission, session_id)
        if not record:
            logger.error(f"Submission {session_id} not found")
            return {"error": "not_found"}

        # Pet-specific template when one exists, otherwise the default.
        template = PET_PROMPTS.get(record.pet_type, DEFAULT_PROMPT)
        rendered_prompt = template.format(
            owner_name=record.owner_name,
            pet_name=record.pet_name,
            music_vibe=record.music_vibe,
        )
        request_body = {
            "prompt": rendered_prompt,
            "instrumental": False,
            "output_format": "mp3",
            "webhook_url": f"{settings.WEBHOOK_BASE_URL}/api/webhook",
            "enable_streaming": True,
            "stream_format": "mp3",
        }

        try:
            api_response = requests.post(
                f"{settings.SONAUTO_API_URL}/generations/v3",
                json=request_body,
                headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
                timeout=settings.API_REQUEST_TIMEOUT_SECONDS,
            )
            api_response.raise_for_status()
            body = api_response.json()

            record.LLM_task_id = body["task_id"]
            record.sent_to_LLM = datetime.utcnow()
            record.entry_status = "processing"
            session.commit()
            logger.info(f"Sent {session_id} to Sonauto, task_id: {body['task_id']}")

            # Fire the per-submission timeout check when the webhook window closes.
            timeout_minutes = settings.WEBHOOK_TIMEOUT_MINUTES
            check_submission_timeout.apply_async(
                args=[session_id],
                countdown=timeout_minutes * 60,
            )
            logger.debug(
                f"Scheduled timeout check for {session_id} in {settings.WEBHOOK_TIMEOUT_MINUTES} minutes"
            )
            return {"task_id": body["task_id"]}
        except requests.RequestException as exc:
            # Compare against None explicitly rather than relying on truthiness.
            failed_response = getattr(exc, "response", None)
            if failed_response is not None:
                logger.error(
                    f"Sonauto API error for {session_id}: {failed_response.status_code} - {failed_response.text}"
                )
                logger.error(f"Payload sent: {request_body}")

            # Track attempts on the row itself, independently of Celery's retry counter.
            record.retry_count += 1
            if record.retry_count >= settings.MAX_RETRIES:
                record.entry_status = "fail"
                logger.error(
                    f"Submission {session_id} failed after {settings.MAX_RETRIES} retries"
                )
            session.commit()
            raise  # Let autoretry_for schedule the retry with backoff
    finally:
        session.close()
@shared_task(bind=True, name="tasks.workers.check_submission_timeout")
def check_submission_timeout(self, session_id: str):
    """
    Fail a single submission if its webhook has not arrived.

    Scheduled by ``send_to_sonauto`` with a countdown equal to the webhook
    timeout, so it runs once that window has elapsed.

    Args:
        session_id: Primary key of the ``Submission`` to inspect.

    Returns:
        dict: One of ``{"error": "not_found"}``,
        ``{"status": "webhook_received"}``, ``{"status": "timed_out"}`` or
        ``{"status": "already_handled"}``.
    """
    session = SessionLocal()
    try:
        record = session.get(Submission, session_id)
        if not record:
            logger.warning(f"Timeout check: submission {session_id} not found")
            return {"error": "not_found"}

        # Webhook arrived in time - nothing to do.
        if record.received_from_LLM is not None:
            logger.debug(f"Timeout check: {session_id} already received webhook")
            return {"status": "webhook_received"}

        # Anything not still "processing" was resolved some other way
        # (e.g. already failed for another reason).
        if record.entry_status != "processing":
            logger.debug(
                f"Timeout check: {session_id} already handled (status: {record.entry_status})"
            )
            return {"status": "already_handled"}

        # Still waiting after the full timeout window: mark as timed out.
        record.entry_status = "fail"
        record.LLM_status = "timeout"
        session.commit()
        logger.warning(f"Submission {session_id} timed out (webhook never arrived)")
        return {"status": "timed_out"}
    finally:
        session.close()
# =============================================================================
# POST-WEBHOOK PROCESSING CHAIN
# =============================================================================
@shared_task(
    bind=True,
    name="tasks.workers.fetch_generation_details",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def fetch_generation_details(self, session_id: str) -> dict:
    """
    Fetch full generation details from the Sonauto API.

    First task in the post-webhook processing chain. Stores the raw API
    response on the submission, validates that the generation succeeded and
    produced at least one song path, then records the song URL and lyrics.

    Args:
        session_id: Primary key of the ``Submission`` whose generation
            (identified by ``LLM_task_id``) should be fetched.

    Returns:
        dict: ``{"session_id", "song_url", "lyrics"}`` for the next task in
        the chain.

    Raises:
        ValueError: If the submission does not exist, the API reports a
            status other than ``SUCCESS``, or no song paths are returned.
            In the latter two cases the submission is marked failed first.
        requests.RequestException: On HTTP/network failure (auto-retried).
    """
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            logger.error(f"Submission {session_id} not found")
            raise ValueError(f"Submission {session_id} not found")

        response = requests.get(
            f"{settings.SONAUTO_API_URL}/generations/{submission.LLM_task_id}",
            headers={"Authorization": f"Bearer {settings.SONAUTO_API_KEY}"},
            timeout=30,
        )
        response.raise_for_status()
        data = response.json()

        # Store full response (committed on both the failure and success paths).
        submission.LLM_full_response = json.dumps(data)

        def _mark_failed(log_message: str) -> None:
            """Persist the failed state and log why the generation is unusable."""
            submission.LLM_status = "fail"
            submission.entry_status = "fail"
            db.commit()
            logger.error(log_message)

        # `or ""` guards against an explicit null status, which would
        # otherwise raise AttributeError before the failure is recorded.
        api_status = str(data.get("status") or "").upper()
        if api_status != "SUCCESS":
            _mark_failed(f"Generation failed for {session_id}: status={api_status}")
            raise ValueError(f"Generation failed with status: {api_status}")

        # Treat a missing or null song_paths the same as an empty list.
        song_paths = data.get("song_paths") or []
        if not song_paths:
            _mark_failed(f"No song_paths in response for {session_id}")
            raise ValueError("No song_paths in Sonauto response")

        song_url = song_paths[0]
        lyrics = data.get("lyrics", "")
        submission.LLM_response = song_url
        submission.lyrics = lyrics
        submission.LLM_status = "success"
        db.commit()
        logger.info(f"Fetched generation details for {session_id}")

        # Return data needed by next task in chain
        return {
            "session_id": session_id,
            "song_url": song_url,
            "lyrics": lyrics,
        }
    finally:
        db.close()
@shared_task(
    bind=True,
    name="tasks.workers.download_audio",
    autoretry_for=(requests.RequestException,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 3},
)
def download_audio(self, result: dict) -> dict:
    """
    Download the generated MP3 from the given song URL and upload it to GCS.

    Second task in the post-webhook processing chain.

    Args:
        result: Output of the previous task; must contain ``session_id`` and
            ``song_url`` and may contain ``lyrics``.

    Returns:
        dict: ``{"session_id", "audio_blob_path", "lyrics"}`` for the next
        task in the chain.

    Raises:
        ValueError: If no submission exists for ``session_id``.
        requests.RequestException: On download failure (auto-retried).
    """
    from app.services import storage

    session_id = result["session_id"]
    song_url = result["song_url"]

    # Download audio into memory
    response = requests.get(song_url, timeout=120)
    response.raise_for_status()
    audio_data = response.content

    # Upload to GCS
    blob_path = storage.get_audio_blob_path(session_id)
    storage.upload_bytes(audio_data, blob_path, content_type="audio/mpeg")

    # Update database with blob path
    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            # Fail with the same explicit error as the sibling tasks instead
            # of an AttributeError on None.
            logger.error(f"Submission {session_id} not found")
            raise ValueError(f"Submission {session_id} not found")
        submission.generated_song_path = blob_path
        db.commit()
        logger.info(f"Uploaded audio for {session_id} to GCS: {blob_path}")
    finally:
        db.close()

    # Return data needed by next task in chain
    return {
        "session_id": session_id,
        "audio_blob_path": blob_path,
        "lyrics": result.get("lyrics", ""),
    }
@shared_task(
    bind=True,
    name="tasks.workers.create_video",
    autoretry_for=(Exception,),
    retry_backoff=True,
    retry_backoff_max=60,
    retry_kwargs={"max_retries": 1},  # Only 1 retry for video generation
)
def create_video(self, result: dict) -> dict:
    """
    Create an MP4 video combining the pet photo with the generated audio.

    Final task in the post-webhook processing chain. Inputs that live in GCS
    are downloaded to temp files, the video (and, when a photo exists, a
    composite image) is generated into temp files, and the outputs are
    uploaded to GCS. Paths starting with ``/`` or ``../`` are treated as
    legacy local files and used directly.

    Args:
        result: Output of the previous task; must contain ``session_id`` and
            either ``audio_blob_path`` (current) or ``audio_path`` (legacy,
            for tasks queued before the GCS migration).

    Returns:
        dict: ``{"session_id", "video_blob_path"}``.

    Raises:
        ValueError: If no audio path is present or the submission is missing.
        FileNotFoundError: If the generator did not produce the video file.
        Exception: Any failure during generation/upload is re-raised after
            the submission has been marked as failed (auto-retried once).
    """
    from contextlib import ExitStack

    from app.services import storage

    session_id = result["session_id"]

    # Support both old (audio_path) and new (audio_blob_path) keys for backward compatibility
    audio_blob_path = result.get("audio_blob_path") or result.get("audio_path")
    if not audio_blob_path:
        raise ValueError("No audio path found in result (checked audio_blob_path and audio_path)")

    # Prefixes that identify an old local path rather than a GCS blob path.
    local_prefixes = ("/", "../")
    is_local_path = audio_blob_path.startswith(local_prefixes)

    db = SessionLocal()
    try:
        submission = db.get(Submission, session_id)
        if not submission:
            raise ValueError(f"Submission {session_id} not found")

        # Mark video creation start
        submission.video_creation_start = datetime.utcnow()
        db.commit()

        try:
            # ExitStack releases every temp resource in reverse order of
            # acquisition, whether the body succeeds or raises. Unlike a
            # manual __exit__(None, None, None), the active exception is
            # forwarded to each context manager on failure.
            with ExitStack() as stack:
                # Audio: legacy local file, or download from GCS.
                if is_local_path:
                    temp_audio_path = audio_blob_path
                else:
                    temp_audio_path = stack.enter_context(
                        storage.temp_download(audio_blob_path, suffix=".mp3")
                    )

                # Pet photo (optional): legacy local file, or download from GCS.
                photo_path = submission.photo_path
                pet_img_temp_path = None
                if photo_path:
                    if photo_path.startswith(local_prefixes):
                        pet_img_temp_path = photo_path
                    else:
                        pet_img_temp_path = stack.enter_context(
                            storage.temp_download(photo_path, suffix=".jpg")
                        )

                # Temp file for video output
                temp_video_path = stack.enter_context(
                    storage.temp_file_for_upload(suffix=".mp4")
                )

                # Temp file for the composite, only when the user uploaded a photo
                composite_output_path = None
                if photo_path:
                    composite_output_path = stack.enter_context(
                        storage.temp_file_for_upload(suffix=".png")
                    )

                # Import and call the video generator
                from video_generator import create_video as video_gen

                video_gen(
                    pet_img_path=pet_img_temp_path,
                    audio_track_path=temp_audio_path,
                    output_path=temp_video_path,
                    composite_output_path=composite_output_path,
                )

                # Verify video was created
                if not Path(temp_video_path).exists():
                    raise FileNotFoundError(f"Video file not created at {temp_video_path}")

                # Upload video to GCS
                video_blob_path = storage.get_video_blob_path(session_id)
                storage.upload_file(temp_video_path, video_blob_path, content_type="video/mp4")

                # Upload composite if it was created
                if composite_output_path and Path(composite_output_path).exists():
                    composite_blob_path = storage.get_composite_blob_path(session_id)
                    storage.upload_file(
                        composite_output_path, composite_blob_path, content_type="image/png"
                    )

            # Mark success (after all temp resources have been released)
            submission.generated_video_path = video_blob_path
            submission.video_creation_end = datetime.utcnow()
            submission.entry_status = "success"
            db.commit()
            logger.info(f"Video created for {session_id} and uploaded to GCS: {video_blob_path}")
            return {"session_id": session_id, "video_blob_path": video_blob_path}
        except Exception as e:
            # Record the failure before re-raising so autoretry can kick in.
            submission.entry_status = "fail"
            submission.video_creation_end = datetime.utcnow()
            db.commit()
            logger.error(f"Video creation failed for {session_id}: {e}")
            raise
    finally:
        db.close()