video-accessibility/backend/app/tasks/rerender_accessible_video.py
Vadym Samoilenko 31199f8705 chore: push all session changes — backend hardening, tests, apache config, deploy scripts
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-30 15:52:14 +01:00

627 lines
26 KiB
Python

"""Celery task for re-rendering accessible video with QC changes."""
import asyncio
import os
import tempfile
from datetime import datetime
from celery.result import allow_join_result
from motor.motor_asyncio import AsyncIOMotorClient
from pydub import AudioSegment
from ..core.config import settings
from ..core.logging import get_logger
from ..lib.vtt import VTTParser
from ..models.job import (
AccessibleVideoEditState,
JobStatus,
)
from ..services.gcs import gcs_path, gcs_service
from ..services.video_renderer import video_renderer_service
from ..services.vtt_retimer import vtt_retimer_service
from ..services.whisper_service import whisper_service
from . import celery_app
from ._websocket_bridge import broadcast_status_update
from .render_accessible_video import (
_dispatch_whisper_transcription,
_extract_audio_for_whisper,
)
from .tts_synthesis import (
parse_ad_cues,
parse_cue_index_from_blob_name,
synthesize_cue_task,
)
# Module-level logger, obtained from the project's core.logging helper.
logger = get_logger(__name__)
@celery_app.task(bind=True, time_limit=7200, soft_time_limit=7000)
def rerender_accessible_video_task(
    self,
    job_id: str,
    language: str,
    regenerate_cue_indices: list[int],
    whisper_refine: bool = False
):
    """
    Re-render accessible video during QC review with selective TTS regeneration.

    This task:
    1. If regenerate_cue_indices not empty: synthesize new TTS for those cues
    2. Download source video and existing segments/MP3s
    3. If whisper_refine: run Whisper pause point refinement
    4. Re-render video using updated pause points and new/existing TTS
    5. Update job status back to PENDING_QC

    On failure the job is handed to ``_mark_rerender_failed`` and the ORIGINAL
    exception is re-raised. An error raised while recording the failure is
    logged but never replaces the original exception, so Celery reports the
    real cause of the failed re-render.

    Args:
        job_id: Job ID
        language: Language being re-rendered
        regenerate_cue_indices: List of cue indices to regenerate TTS for
        whisper_refine: Whether to run Whisper pause point refinement

    Returns:
        Whatever ``_async_rerender_accessible_video`` returns.

    Raises:
        Exception: Any error from the re-render is re-raised after the failure
            has been recorded (best-effort).
    """
    logger.info(
        f"Starting accessible video re-render for job {job_id}/{language}: "
        f"regenerate={regenerate_cue_indices}, whisper_refine={whisper_refine}"
    )
    try:
        result = asyncio.run(_async_rerender_accessible_video(
            job_id, language, regenerate_cue_indices, whisper_refine
        ))
    except Exception as e:
        logger.error(f"Accessible video re-render failed for job {job_id}/{language}: {e}")
        import traceback
        logger.error(f"Full traceback: {traceback.format_exc()}")
        # Update job status back to PENDING_QC with error. Guard this call:
        # if it raises (e.g. DB unreachable) it must not mask the original error.
        try:
            asyncio.run(_mark_rerender_failed(job_id, language, str(e)))
        except Exception as mark_error:
            logger.error(
                f"Failed to record re-render failure for job {job_id}/{language}: {mark_error}"
            )
        # Bare raise re-raises `e`: the inner handler above has already exited.
        raise
    logger.info(f"Accessible video re-render completed for job {job_id}/{language}")
    return result
async def _mark_rerender_failed(job_id: str, language: str, error_message: str):
    """Mark re-render as failed and return the job to PENDING_QC.

    The status revert, the ``last_render_error`` field and a ``review.history``
    entry are written only while the job is still ``RENDERING_QC``. This is the
    same guard the success path uses. A job that was cancelled or overridden
    mid-render is therefore not flipped back to ``PENDING_QC``. In that case
    only ``last_render_error`` is recorded (status untouched) and no status
    broadcast is sent.

    Args:
        job_id: Job ID.
        language: Language whose re-render failed.
        error_message: Error text. Stored in full; truncated to 200 chars in
            the history note and 100 chars in the broadcast message.
    """
    client = AsyncIOMotorClient(settings.mongodb_uri)
    db = client[settings.mongodb_db]
    error_field = f"outputs.{language}.accessible_video_edit_state.last_render_error"
    try:
        now = datetime.utcnow()
        result = await db.jobs.update_one(
            # Only revert if still rendering (mirrors the completion update).
            {"_id": job_id, "status": JobStatus.RENDERING_QC.value},
            {
                "$set": {
                    "status": JobStatus.PENDING_QC.value,
                    error_field: error_message,
                    "updated_at": now
                },
                "$push": {
                    "review.history": {
                        "at": now,
                        "status": JobStatus.PENDING_QC.value,
                        "by": "system",
                        "notes": f"Re-render failed for {language}: {error_message[:200]}"
                    }
                }
            }
        )
        if result.modified_count == 0:
            logger.warning(
                f"Re-render failure for job {job_id}/{language} not applied to status: "
                f"job is no longer {JobStatus.RENDERING_QC.value}; recording error only"
            )
            # Keep the error visible to reviewers without touching the status.
            await db.jobs.update_one(
                {"_id": job_id},
                {"$set": {error_field: error_message, "updated_at": now}}
            )
            return
        job_doc = await db.jobs.find_one({"_id": job_id})
        broadcast_status_update(
            job_id,
            JobStatus.PENDING_QC.value,
            job_title=job_doc.get("title") if job_doc else None,
            message=f"Re-render failed: {error_message[:100]}"
        )
    finally:
        client.close()
async def _async_rerender_accessible_video(
    job_id: str,
    language: str,
    regenerate_cue_indices: list[int],
    whisper_refine: bool
):
    """Async implementation of accessible video re-rendering.

    Steps (scratch files live in a temporary directory, under ``$TMPDIR``
    when that variable is set):

    1. Download the source video from GCS.
    2. If ``regenerate_cue_indices`` is non-empty, re-synthesize TTS for those
       cues, merge the new GCS URIs into ``ad_cue_manifest`` and clear the
       edit state's ``tts_regeneration_queue``.
    3. Download the AD VTT and the per-cue MP3s, measuring each MP3's
       duration. The manifest drives the download, with a legacy
       blob-listing fallback.
    4. Build placements from the QC-edited pause points.
    5. Optionally refine pause points with Whisper (``pause_insert`` only).
    6. Render the accessible video, persisting segments for future edits.
    7. Upload the rendered video.
    8. For ``pause_insert``, re-time and upload the captions VTT.
    9. Build the new edit state.
    10. Write outputs and move the job back to ``PENDING_QC``, but only if
        it is still ``RENDERING_QC``. The completion broadcast is sent only
        when that guarded update was applied.

    Raises:
        ValueError: If the job, its language outputs, edit state, AD VTT or
            AD cue prefix are missing, or the job disappears mid-render.
    """
    logger.info(f"Async re-render started for job {job_id}/{language}")
    client = AsyncIOMotorClient(settings.mongodb_uri)
    db = client[settings.mongodb_db]
    try:
        # Get job details
        job_doc = await db.jobs.find_one({"_id": job_id})
        if not job_doc:
            raise ValueError(f"Job {job_id} not found")
        job_title = job_doc.get("title", "Untitled Job")
        lang_output = job_doc.get("outputs", {}).get(language)
        if not lang_output:
            raise ValueError(f"No outputs found for language {language}")
        edit_state = lang_output.get("accessible_video_edit_state")
        if not edit_state:
            raise ValueError(f"No edit state found for language {language}")

        # Use TMPDIR env var if set
        temp_base = os.environ.get('TMPDIR', None)
        with tempfile.TemporaryDirectory(dir=temp_base) as temp_dir:
            # 1. Download source video
            source_video_gcs = job_doc["source"]["gcs_uri"]
            source_blob_path = source_video_gcs.replace(f"gs://{settings.gcs_bucket}/", "")
            source_video_path = os.path.join(temp_dir, "source.mp4")
            logger.info(f"Downloading source video from {source_blob_path}")
            source_blob = gcs_service.bucket.blob(source_blob_path)
            source_blob.download_to_filename(source_video_path)

            # 2. Regenerate TTS for queued cues (if any)
            if regenerate_cue_indices:
                logger.info(f"Regenerating TTS for cues: {regenerate_cue_indices}")
                regen_results = await _regenerate_tts_cues(
                    job_id, language, regenerate_cue_indices, job_doc, db, temp_dir
                )
                # Update manifest with new GCS URIs for regenerated cues
                if regen_results:
                    # Merge into the latest stored manifest; tolerate the job
                    # having vanished (step 3 raises a clear error for that).
                    job_doc_after_regen = await db.jobs.find_one({"_id": job_id}) or {}
                    current_manifest = (
                        job_doc_after_regen.get("outputs", {})
                        .get(language, {})
                        .get("ad_cue_manifest") or []
                    )
                    manifest_by_idx = {e["cue_index"]: e for e in current_manifest}
                    for r in regen_results:
                        if r.get("success") and r.get("gcs_uri"):
                            manifest_by_idx[r["cue_index"]] = {
                                "cue_index": r["cue_index"],
                                "gcs_uri": r["gcs_uri"],
                                "text": (r.get("text") or "")[:80],
                                "duration_s": r.get("duration", 0.0)
                            }
                    updated_manifest = sorted(manifest_by_idx.values(), key=lambda e: e["cue_index"])
                    await db.jobs.update_one(
                        {"_id": job_id},
                        {"$set": {f"outputs.{language}.ad_cue_manifest": updated_manifest}}
                    )
                    logger.info(f"Updated ad_cue_manifest with {len(regen_results)} regenerated cues")
                # Clear regeneration queue after successful synthesis
                await db.jobs.update_one(
                    {"_id": job_id},
                    {
                        "$set": {
                            f"outputs.{language}.accessible_video_edit_state.tts_regeneration_queue": [],
                            "updated_at": datetime.utcnow()
                        }
                    }
                )

            # 3. Download AD VTT and per-cue MP3s
            ad_vtt_gcs = lang_output.get("ad_vtt_gcs")
            if not ad_vtt_gcs:
                raise ValueError(f"No AD VTT found for language {language}")
            ad_blob_path = ad_vtt_gcs.replace(f"gs://{settings.gcs_bucket}/", "")
            ad_blob = gcs_service.bucket.blob(ad_blob_path)
            ad_vtt_content = ad_blob.download_as_text()

            # Download per-cue MP3s
            ad_cues_prefix = lang_output.get("ad_cues_gcs_prefix")
            if not ad_cues_prefix:
                raise ValueError(f"No AD cue segments found for language {language}")
            ad_segments = []
            cue_durations = []

            # Re-fetch job doc to get updated manifest after TTS regen
            job_doc = await db.jobs.find_one({"_id": job_id})
            if not job_doc:
                raise ValueError(f"Job {job_id} disappeared during re-render")
            lang_output = job_doc.get("outputs", {}).get(language) or {}
            ad_cue_manifest = lang_output.get("ad_cue_manifest")
            if ad_cue_manifest:
                logger.info(f"Using ad_cue_manifest ({len(ad_cue_manifest)} entries) for MP3 download")
                for entry in sorted(ad_cue_manifest, key=lambda e: e["cue_index"]):
                    cue_index = entry["cue_index"]
                    gcs_uri = entry["gcs_uri"]
                    blob_path = gcs_uri.replace(f"gs://{settings.gcs_bucket}/", "")
                    local_path = os.path.join(temp_dir, f"cue_{cue_index}.mp3")
                    gcs_service.bucket.blob(blob_path).download_to_filename(local_path)
                    ad_segments.append((cue_index, local_path))
                    audio = AudioSegment.from_mp3(local_path)
                    cue_durations.append(len(audio) / 1000.0)
            else:
                logger.warning(
                    f"No ad_cue_manifest for job {job_id}/{language}: "
                    "falling back to legacy index-based blob listing. "
                    "Cue insertions/deletions may cause MP3/VTT desync."
                )
                prefix_path = ad_cues_prefix.replace(f"gs://{settings.gcs_bucket}/", "")
                blobs = list(gcs_service.bucket.list_blobs(prefix=prefix_path))
                cue_blobs = [
                    (b, parse_cue_index_from_blob_name(b.name))
                    for b in blobs if b.name.endswith(".mp3")
                ]
                cue_blobs = [(b, idx) for b, idx in cue_blobs if idx is not None]
                cue_blobs.sort(key=lambda x: x[1])
                for blob, cue_index in cue_blobs:
                    local_path = os.path.join(temp_dir, f"cue_{cue_index}.mp3")
                    blob.download_to_filename(local_path)
                    ad_segments.append((cue_index, local_path))
                    audio = AudioSegment.from_mp3(local_path)
                    cue_durations.append(len(audio) / 1000.0)
            logger.info(f"Downloaded {len(ad_segments)} AD cue segments")

            # Validate VTT cue count matches MP3 count. NOTE(review): cue_durations
            # is positional, so a gap in cue indices shifts durations onto the
            # wrong cues in _build_placements_with_adjustments.
            vtt_cues = VTTParser.parse(ad_vtt_content)
            downloaded_indices = {idx for idx, _ in ad_segments}
            if len(vtt_cues) != len(ad_segments):
                missing_indices = set(range(len(vtt_cues))) - downloaded_indices
                logger.warning(
                    f"VTT cue count ({len(vtt_cues)}) does not match MP3 count ({len(ad_segments)}). "
                    f"Missing MP3s for cue indices: {sorted(missing_indices)}. "
                    f"This may happen when a new AD cue is inserted but TTS wasn't regenerated for shifted cues."
                )

            # 4. Build placements with adjusted pause points
            method = lang_output.get("accessible_video_method", "pause_insert")
            pause_points = edit_state.get("pause_points", [])
            if not pause_points:
                logger.info("No pause points in edit state — using VTT cue start times for all placements")
            placements = _build_placements_with_adjustments(
                ad_vtt_content, cue_durations, pause_points
            )
            logger.info(f"Built {len(placements)} placements with adjusted pause points")
            analysis = {
                "method": method,
                "method_rationale": "QC re-render with user adjustments",
                "placements": placements,
                "total_added_duration": sum(cue_durations) if method == "pause_insert" else 0,
                "warnings": []
            }

            # 5. Optionally run Whisper refinement
            if whisper_refine and method == "pause_insert":
                logger.info("Running Whisper pause point refinement...")
                analysis, whisper_warnings = await _refine_pause_points_for_rerender(
                    job_id, source_video_path, analysis, db, temp_dir
                )
                if whisper_warnings:
                    analysis["warnings"] = analysis.get("warnings", []) + whisper_warnings
                logger.info(f"Whisper refinement complete with {len(whisper_warnings)} warnings")

            # 6. Render accessible video (persist segments again for future edits)
            output_video_path = os.path.join(temp_dir, "accessible_video.mp4")
            gcs_segment_prefix = gcs_path(job_doc, language, "segments") + "/"
            logger.info(f"Re-rendering accessible video using {method} method...")
            (
                _rendered_path,
                updated_placements,
                segment_metadata,
                new_pause_points,
            ) = await video_renderer_service.render_accessible_video(
                source_video_path,
                ad_segments,
                analysis,
                output_video_path,
                persist_segments=True,
                gcs_segment_prefix=gcs_segment_prefix
            )
            if updated_placements:
                analysis["placements"] = updated_placements

            # 7. Upload rendered video
            video_blob_path = gcs_path(job_doc, language, "accessible_video.mp4")
            video_blob = gcs_service.bucket.blob(video_blob_path)
            video_blob.content_type = "video/mp4"
            video_blob.upload_from_filename(output_video_path)
            video_gcs_uri = f"gs://{settings.gcs_bucket}/{video_blob_path}"
            logger.info(f"Uploaded re-rendered accessible video to {video_gcs_uri}")

            # 8. Generate re-timed captions if pause-insert
            retimed_captions_gcs_uri = None
            if method == "pause_insert":
                captions_vtt_gcs = lang_output.get("captions_vtt_gcs")
                if captions_vtt_gcs:
                    captions_blob_path = captions_vtt_gcs.replace(f"gs://{settings.gcs_bucket}/", "")
                    captions_blob = gcs_service.bucket.blob(captions_blob_path)
                    original_captions_vtt = captions_blob.download_as_text()
                    retimed_captions = vtt_retimer_service.retime_for_pause_insert(
                        original_captions_vtt, analysis
                    )
                    retimed_blob_path = gcs_path(job_doc, language, "accessible_captions.vtt")
                    retimed_blob = gcs_service.bucket.blob(retimed_blob_path)
                    retimed_blob.content_type = "text/vtt"
                    retimed_blob.upload_from_string(retimed_captions, content_type="text/vtt")
                    retimed_captions_gcs_uri = f"gs://{settings.gcs_bucket}/{retimed_blob_path}"
                    logger.info(f"Uploaded re-timed captions to {retimed_captions_gcs_uri}")

            # 9. Build new edit state
            new_edit_state = None
            if segment_metadata and new_pause_points:
                new_edit_state = AccessibleVideoEditState(
                    pause_points=new_pause_points,
                    video_segments=segment_metadata,
                    tts_regeneration_queue=[],
                    last_render_at=datetime.utcnow(),
                    whisper_refine_enabled=whisper_refine
                )

            # 10. Update job document
            update_fields = {
                f"outputs.{language}.accessible_video_gcs": video_gcs_uri,
                f"outputs.{language}.video_segments_gcs_prefix": f"gs://{settings.gcs_bucket}/{gcs_segment_prefix}",
                "status": JobStatus.PENDING_QC.value,
                "updated_at": datetime.utcnow()
            }
            if retimed_captions_gcs_uri:
                update_fields[f"outputs.{language}.retimed_captions_vtt_gcs"] = retimed_captions_gcs_uri
            if new_edit_state:
                update_fields[f"outputs.{language}.accessible_video_edit_state"] = new_edit_state.model_dump()
            completion_result = await db.jobs.update_one(
                {"_id": job_id, "status": JobStatus.RENDERING_QC.value},  # Only complete if still rendering
                {
                    "$set": update_fields,
                    "$push": {
                        "review.history": {
                            "at": datetime.utcnow(),
                            "status": JobStatus.PENDING_QC.value,
                            "by": "system",
                            "notes": f"Re-render complete for {language}"
                        }
                    }
                }
            )
            if completion_result.modified_count == 0:
                logger.warning(
                    f"Re-render completion update skipped for job {job_id}/{language}: "
                    "job status changed during render (may have been cancelled or overridden)"
                )
            else:
                # Broadcast completion only when the job really moved to PENDING_QC.
                broadcast_status_update(
                    job_id,
                    JobStatus.PENDING_QC.value,
                    job_title=job_title,
                    message=f"Accessible video re-render complete for {language.upper()}"
                )
            logger.info(f"Accessible video re-render complete for job {job_id}/{language}")
    finally:
        client.close()
async def _regenerate_tts_cues(
    job_id: str,
    language: str,
    cue_indices: list[int],
    job_doc: dict,
    db,
    temp_dir: str
) -> list[dict]:
    """
    Regenerate TTS for specific cues using current AD VTT text.

    Each requested cue is dispatched to ``synthesize_cue_task`` on the ``tts``
    queue and awaited (polling once per second) before the next cue is
    dispatched. Indices outside ``[0, len(cues))`` are logged and skipped.
    Negative indices are rejected rather than silently wrapping to the last
    cues.

    Args:
        job_id: Job ID.
        language: Language whose AD VTT and voice preferences are used.
        cue_indices: Zero-based AD cue indices to re-synthesize.
        job_doc: Job document. Reads ``outputs[language].ad_vtt_gcs``,
            ``requested_outputs.tts_preferences``, ``client_id`` and
            ``cost_tracker_project_id``.
        db: Not used here; kept for call-site compatibility.
        temp_dir: Not used here; kept for call-site compatibility.

    Returns:
        List of synthesis result dicts [{cue_index, gcs_uri, text, duration, success, ...}]
        so the caller can update the ad_cue_manifest.

    Raises:
        ValueError: If no AD VTT is recorded for ``language``.
        Exception: If any cue's synthesis result reports failure.
    """
    logger.info(f"Regenerating TTS for {len(cue_indices)} cues")

    # Get AD VTT content (fail clearly instead of AttributeError on None).
    lang_output = job_doc.get("outputs", {}).get(language) or {}
    ad_vtt_gcs = lang_output.get("ad_vtt_gcs")
    if not ad_vtt_gcs:
        raise ValueError(f"No AD VTT found for language {language}")
    ad_blob_path = ad_vtt_gcs.replace(f"gs://{settings.gcs_bucket}/", "")
    ad_blob = gcs_service.bucket.blob(ad_blob_path)
    ad_vtt_content = ad_blob.download_as_text()

    # Parse cues
    cues = parse_ad_cues(ad_vtt_content)

    # Get TTS preferences
    tts_preferences = (job_doc.get("requested_outputs") or {}).get("tts_preferences") or {}
    voices_per_language = tts_preferences.get("voices_per_language", {})
    voice_name = voices_per_language.get(language, tts_preferences.get("default_voice"))
    provider = tts_preferences.get("provider", "gemini")
    model = tts_preferences.get("model", "flash")
    speed = tts_preferences.get("speed", 1.0)
    style_preset = tts_preferences.get("style_preset", "neutral")
    custom_style_prompt = tts_preferences.get("custom_style_prompt")
    if style_preset == "custom" and custom_style_prompt:
        style_prompt = custom_style_prompt
    else:
        style_prompt = settings.gemini_tts_style_prompts.get(style_preset, "")

    regen_results = []
    # Synthesize each cue
    for cue_idx in cue_indices:
        # Reject negatives too: cues[-1] would silently re-synthesize the last cue.
        if not 0 <= cue_idx < len(cues):
            logger.warning(f"Cue index {cue_idx} out of range, skipping")
            continue
        cue = cues[cue_idx]
        logger.info(f"Synthesizing TTS for cue {cue_idx}: '{cue['text'][:50]}...'")

        # Dispatch synthesis task
        task_result = synthesize_cue_task.apply_async(
            kwargs={
                "job_id": job_id,
                "language": language,
                "cue_index": cue_idx,
                "text": cue["text"],
                "start_time": cue["start_time"],
                "end_time": cue["end_time"],
                "voice_name": voice_name,
                "provider": provider,
                "model": model,
                "speed": speed,
                "style_prompt": style_prompt,
                "user_id": job_doc.get("client_id", "system"),
                "cost_project_id": job_doc.get("cost_tracker_project_id"),
            },
            queue="tts"
        )

        # Wait for completion without blocking the event loop.
        poll_count = 0
        while not task_result.ready():
            await asyncio.sleep(1.0)
            poll_count += 1
            if poll_count % 30 == 0:
                logger.info(f"Still waiting for TTS cue {cue_idx}...")
        # allow_join_result permits .get() from inside a running task.
        with allow_join_result():
            result = task_result.get(timeout=120)
        if not result.get("success"):
            raise Exception(f"TTS synthesis failed for cue {cue_idx}: {result.get('error_message')}")
        regen_results.append(result)
        logger.info(f"TTS synthesis complete for cue {cue_idx}")

    logger.info(f"Regenerated {len(regen_results)}/{len(cue_indices)} TTS cues")
    return regen_results
def _build_placements_with_adjustments(
    ad_vtt_content: str,
    cue_durations: list[float],
    pause_points: list[dict]
) -> list[dict]:
    """
    Build placement instructions using adjusted pause points from QC edits.

    Uses ``source_ms`` (source video coordinates) for pause point
    calculations, applying user adjustments as relative offsets. When both
    ``original_ms`` and ``adjusted_ms`` are present, ``adjusted_ms -
    original_ms`` is added to ``source_ms``.

    Other pause-point entries are handled as follows:
    - Entries without ``source_ms`` fall back to ``original_ms``, or the
      cue's VTT start, with no delta applied.
    - Entries without a ``cue_index`` cannot be matched to a cue and are
      skipped.

    One placement is produced per VTT cue, truncated to
    ``len(cue_durations)``. Cues with no pause-point entry pause at their
    VTT start time. Pause points are then clamped so they are
    non-decreasing in cue order.

    Args:
        ad_vtt_content: AD VTT content
        cue_durations: TTS durations per cue, aligned by position with the VTT cues
        pause_points: Pause point dicts with cue_index, source_ms, original_ms, adjusted_ms
    Returns:
        List of placement dicts (``*_ms`` inputs are converted to seconds)
    """
    cues = VTTParser.parse(ad_vtt_content)

    # Build lookup of pause points by cue index using SOURCE coordinates
    adjusted_pause_by_cue = {}
    for pp in pause_points:
        cue_idx = pp.get("cue_index")
        if cue_idx is None:
            # Unmatchable; previously could crash on `None < len(cues)`.
            logger.warning(f"Skipping pause point without cue_index: {pp}")
            continue
        source_ms = pp.get("source_ms")
        original_ms = pp.get("original_ms")
        adjusted_ms = pp.get("adjusted_ms")

        # Fallback for data without source_ms (backward compatibility).
        # When source_ms is missing we cannot reliably map rendered adjustments
        # back to source coordinates, so we skip the delta and use original_ms as-is.
        if source_ms is None:
            logger.warning(
                f"Cue {cue_idx}: No source_ms found, falling back to original_ms "
                "without applying adjustment delta. "
                "Job may need to be re-processed from initial render for timing adjustments to work."
            )
            if original_ms is not None:
                adjusted_pause_by_cue[cue_idx] = original_ms / 1000.0
            elif 0 <= cue_idx < len(cues):
                adjusted_pause_by_cue[cue_idx] = cues[cue_idx].start_time
            else:
                adjusted_pause_by_cue[cue_idx] = 0.0
            continue

        # Apply user adjustment as relative offset in source coordinates
        if adjusted_ms is not None and original_ms is not None:
            # User adjusted in rendered timeline - apply same delta to source
            adjustment_delta = adjusted_ms - original_ms
            adjusted_source_ms = source_ms + adjustment_delta
            logger.info(
                f"Cue {cue_idx}: Applying adjustment delta {adjustment_delta:.1f}ms "
                f"(rendered: {original_ms:.1f} -> {adjusted_ms:.1f}, "
                f"source: {source_ms:.1f} -> {adjusted_source_ms:.1f})"
            )
            source_ms = adjusted_source_ms

        # Convert to seconds for placement
        adjusted_pause_by_cue[cue_idx] = source_ms / 1000.0

    # zip() stops at the shorter sequence: cues without a duration get no placement.
    placements = []
    for i, (cue, ad_duration) in enumerate(zip(cues, cue_durations)):
        # Get pause point: use source-based value if available, otherwise fall back to VTT
        pause_point = adjusted_pause_by_cue.get(i, cue.start_time)
        placements.append({
            "ad_cue_index": i,
            "original_start_time": cue.start_time,
            "original_end_time": cue.end_time,
            "target_start_time": cue.start_time,
            "ad_duration": ad_duration,
            "pause_point": pause_point,
            "resume_from": pause_point,
            "pause_point_rationale": "User-adjusted during QC" if i in adjusted_pause_by_cue else "Original from VTT"
        })

    # Enforce pause_point monotonicity - pause_points must be non-decreasing in cue order.
    # User-adjusted pause points can cross over each other; clamp to maintain cue order.
    # The dicts are shared, so a clamp on `curr` is seen as `prev` on the next pair.
    for prev, curr in zip(placements, placements[1:]):
        prev_pp = prev.get("pause_point")
        curr_pp = curr.get("pause_point")
        if curr_pp is not None and prev_pp is not None and curr_pp < prev_pp:
            logger.warning(
                f"Rerender monotonicity fix: cue {curr.get('ad_cue_index')} "
                f"pause_point {curr_pp:.2f}s < cue {prev.get('ad_cue_index')} "
                f"pause_point {prev_pp:.2f}s, clamping to {prev_pp:.2f}s"
            )
            curr["pause_point"] = prev_pp
            curr["resume_from"] = prev_pp

    return placements
async def _refine_pause_points_for_rerender(
    job_id: str,
    video_path: str,
    analysis: dict,
    db,
    temp_dir: str
) -> tuple[dict, list[str]]:
    """Run Whisper pause point refinement for re-render.

    Refinement is best-effort: if audio extraction or transcription fails,
    or no speech is detected, ``analysis`` is returned unchanged together
    with a single explanatory warning. Otherwise a shallow copy of
    ``analysis`` is returned with refined ``placements`` and
    ``whisper_refined=True``, plus the refiner's warnings.

    Args:
        job_id: Job ID (passed to the transcription dispatcher).
        video_path: Local path of the source video.
        analysis: Render analysis dict containing ``placements``.
        db: Not used here; kept for call-site compatibility.
        temp_dir: Scratch directory for the extracted audio file.

    Returns:
        ``(analysis, warnings)``.
    """
    logger.info(f"Refining pause points with Whisper for re-render of job {job_id}")
    audio_path = os.path.join(temp_dir, "source_audio.mp3")
    # Extraction is part of the optional refinement, so its failure degrades
    # gracefully like a transcription failure instead of aborting the render.
    try:
        await _extract_audio_for_whisper(video_path, audio_path)
    except Exception as e:
        logger.error(f"Audio extraction for Whisper failed: {e}")
        return analysis, [f"Whisper audio extraction failed: {str(e)} - using current timestamps"]

    try:
        words = await _dispatch_whisper_transcription(job_id, audio_path)
    except Exception as e:
        logger.error(f"Whisper transcription failed: {e}")
        return analysis, [f"Whisper failed: {str(e)} - using current timestamps"]
    if not words:
        return analysis, ["No speech detected - using current timestamps"]

    gaps = whisper_service.identify_speech_gaps(words)
    refined_placements, warnings = whisper_service.refine_all_pause_points(
        analysis.get("placements", []),
        words,
        gaps
    )
    refined_analysis = analysis.copy()
    refined_analysis["placements"] = refined_placements
    refined_analysis["whisper_refined"] = True
    return refined_analysis, warnings