"""Celery task for re-rendering accessible video with QC changes."""

import asyncio
import io
import os
import tempfile
import time
import traceback
from datetime import datetime

from celery.result import allow_join_result
from motor.motor_asyncio import AsyncIOMotorClient
from pydub import AudioSegment

from ..core.config import settings
from ..core.logging import get_logger
from ..lib.vtt import VTTParser
from ..models.job import AccessibleVideoEditState, JobStatus, PausePointData, VideoSegmentMetadata
from ..services.gcs import gcs_service
from ..services.video_renderer import video_renderer_service
from ..services.vtt_retimer import vtt_retimer_service
from ..services.whisper_service import WordTimestamp, whisper_service
from . import celery_app
from .render_accessible_video import _extract_audio_for_whisper, _dispatch_whisper_transcription
from .translate_and_synthesize import broadcast_status_update
from .tts_synthesis import dispatch_language_tts, parse_ad_cues, parse_cue_index_from_blob_name, synthesize_cue_task

logger = get_logger(__name__)

# Upper bound on how long we wait for a single dispatched TTS cue task once we
# start waiting on it. Without a bound, the poll loop is only stopped by the
# Celery time limit, which gives a much less useful error.
_TTS_CUE_WAIT_TIMEOUT_S = 900.0
# Interval between non-blocking readiness checks on a dispatched TTS cue task.
_TTS_CUE_POLL_INTERVAL_S = 1.0


def _gcs_uri_to_blob_path(gcs_uri: str) -> str:
    """Return the blob path for *gcs_uri* by stripping the configured bucket prefix.

    Only a leading ``gs://<settings.gcs_bucket>/`` is removed. URIs for other
    buckets are returned unchanged.
    """
    return gcs_uri.removeprefix(f"gs://{settings.gcs_bucket}/")


@celery_app.task(bind=True, time_limit=7200, soft_time_limit=7000)
def rerender_accessible_video_task(
    self, job_id: str, language: str, regenerate_cue_indices: list[int], whisper_refine: bool = False
):
    """
    Re-render accessible video during QC review with selective TTS regeneration.

    This task:
    1. If regenerate_cue_indices is not empty: synthesizes new TTS for those cues
    2. Downloads the source video and existing AD VTT / per-cue MP3s
    3. If whisper_refine: runs Whisper pause point refinement
    4. Re-renders the video using updated pause points and new/existing TTS
    5. Moves the job back to PENDING_QC (only if it is still RENDERING_QC)

    On failure the error is recorded on the edit state and the job is returned
    to PENDING_QC (again only if still RENDERING_QC), then the original
    exception is re-raised so Celery records the task as failed.

    Args:
        job_id: Job ID
        language: Language being re-rendered
        regenerate_cue_indices: List of cue indices to regenerate TTS for
        whisper_refine: Whether to run Whisper pause point refinement
    """
    logger.info(
        f"Starting accessible video re-render for job {job_id}/{language}: "
        f"regenerate={regenerate_cue_indices}, whisper_refine={whisper_refine}"
    )
    try:
        result = asyncio.run(_async_rerender_accessible_video(
            job_id, language, regenerate_cue_indices, whisper_refine
        ))
        logger.info(f"Accessible video re-render completed for job {job_id}/{language}")
        return result
    except Exception as e:
        logger.error(f"Accessible video re-render failed for job {job_id}/{language}: {e}")
        logger.error(f"Full traceback: {traceback.format_exc()}")
        # Some exceptions (e.g. a bare TimeoutError()) stringify to "", which
        # would leave QC with an empty error message.
        error_message = str(e) or type(e).__name__
        try:
            asyncio.run(_mark_rerender_failed(job_id, language, error_message))
        except Exception as mark_err:
            # Bookkeeping must never mask the original render failure.
            logger.error(
                f"Failed to record re-render failure for job {job_id}/{language}: {mark_err}"
            )
        raise


async def _mark_rerender_failed(job_id: str, language: str, error_message: str):
    """Record a re-render failure and return the job to PENDING_QC.

    The error is always written to
    ``outputs.<language>.accessible_video_edit_state.last_render_error``.
    The status transition, history entry and broadcast happen only while the
    job is still RENDERING_QC. This mirrors the success path's guard, so a job
    that was cancelled or overridden mid-render is not resurrected.
    """
    client = AsyncIOMotorClient(settings.mongodb_uri)
    db = client[settings.mongodb_db]
    error_field = f"outputs.{language}.accessible_video_edit_state.last_render_error"
    try:
        status_result = await db.jobs.update_one(
            {"_id": job_id, "status": JobStatus.RENDERING_QC.value},
            {
                "$set": {
                    "status": JobStatus.PENDING_QC.value,
                    error_field: error_message,
                    "updated_at": datetime.utcnow()
                },
                "$push": {
                    "review.history": {
                        "at": datetime.utcnow(),
                        "status": JobStatus.PENDING_QC.value,
                        "by": "system",
                        "notes": f"Re-render failed for {language}: {error_message[:200]}"
                    }
                }
            }
        )
        if status_result.matched_count == 0:
            # Status moved on during the render; keep the error for QC but do
            # not change the status or broadcast a misleading transition.
            await db.jobs.update_one(
                {"_id": job_id},
                {"$set": {error_field: error_message, "updated_at": datetime.utcnow()}}
            )
            logger.warning(
                f"Re-render failure for job {job_id}/{language} recorded without status change — "
                "job is no longer in RENDERING_QC"
            )
            return

        job_doc = await db.jobs.find_one({"_id": job_id})
        broadcast_status_update(
            job_id,
            JobStatus.PENDING_QC.value,
            job_title=job_doc.get("title") if job_doc else None,
            message=f"Re-render failed: {error_message[:100]}"
        )
    finally:
        client.close()


def _merge_regen_results_into_manifest(current_manifest: list[dict], regen_results: list[dict]) -> list[dict]:
    """Overlay successful TTS regeneration results onto an ad_cue_manifest.

    Entries are keyed by ``cue_index``. A result replaces the existing entry
    only if it reports ``success`` and carries a ``gcs_uri``. The merged
    manifest is returned sorted by ``cue_index``.
    """
    manifest_by_idx = {entry["cue_index"]: entry for entry in current_manifest}
    for r in regen_results:
        if r.get("success") and r.get("gcs_uri"):
            manifest_by_idx[r["cue_index"]] = {
                "cue_index": r["cue_index"],
                "gcs_uri": r["gcs_uri"],
                "text": (r.get("text") or "")[:80],
                "duration_s": r.get("duration", 0.0)
            }
    return sorted(manifest_by_idx.values(), key=lambda e: e["cue_index"])


def _download_cue_mp3(blob, cue_index: int, temp_dir: str) -> tuple[str, float]:
    """Download one per-cue MP3 blob into *temp_dir*; return (local_path, duration_seconds)."""
    local_path = os.path.join(temp_dir, f"cue_{cue_index}.mp3")
    blob.download_to_filename(local_path)
    duration_s = len(AudioSegment.from_mp3(local_path)) / 1000.0
    return local_path, duration_s


def _download_ad_cue_segments(
    job_id: str, language: str, lang_output: dict, ad_cues_prefix: str, temp_dir: str
) -> tuple[list[tuple[int, str]], list[float]]:
    """Download the per-cue AD MP3s for *language*.

    The ``ad_cue_manifest`` is preferred. Without it, this falls back to listing
    ``.mp3`` blobs under *ad_cues_prefix* and parsing cue indices from their names.

    Returns:
        (ad_segments, cue_durations): ``ad_segments`` is a list of
        ``(cue_index, local_path)`` sorted by cue index, and ``cue_durations`` is
        the matching list of durations in seconds, positionally aligned with it.
    """
    ad_cue_manifest = lang_output.get("ad_cue_manifest")
    if ad_cue_manifest:
        logger.info(f"Using ad_cue_manifest ({len(ad_cue_manifest)} entries) for MP3 download")
        indexed_blobs = [
            (entry["cue_index"], gcs_service.bucket.blob(_gcs_uri_to_blob_path(entry["gcs_uri"])))
            for entry in sorted(ad_cue_manifest, key=lambda e: e["cue_index"])
        ]
    else:
        logger.warning(
            f"No ad_cue_manifest for job {job_id}/{language} — "
            "falling back to legacy index-based blob listing. "
            "Cue insertions/deletions may cause MP3/VTT desync."
        )
        prefix_path = _gcs_uri_to_blob_path(ad_cues_prefix)
        indexed_blobs = []
        for blob in gcs_service.bucket.list_blobs(prefix=prefix_path):
            if not blob.name.endswith(".mp3"):
                continue
            cue_index = parse_cue_index_from_blob_name(blob.name)
            if cue_index is not None:
                indexed_blobs.append((cue_index, blob))
        indexed_blobs.sort(key=lambda pair: pair[0])

    ad_segments: list[tuple[int, str]] = []
    cue_durations: list[float] = []
    for cue_index, blob in indexed_blobs:
        local_path, duration_s = _download_cue_mp3(blob, cue_index, temp_dir)
        ad_segments.append((cue_index, local_path))
        cue_durations.append(duration_s)
    return ad_segments, cue_durations


async def _async_rerender_accessible_video(
    job_id: str, language: str, regenerate_cue_indices: list[int], whisper_refine: bool
):
    """Async implementation of accessible video re-rendering.

    Raises:
        ValueError: if the job, its outputs for *language*, its edit state, its
            AD VTT, or its AD cue prefix is missing.
    """
    logger.info(f"Async re-render started for job {job_id}/{language}")

    client = AsyncIOMotorClient(settings.mongodb_uri)
    db = client[settings.mongodb_db]

    try:
        # Get job details
        job_doc = await db.jobs.find_one({"_id": job_id})
        if not job_doc:
            raise ValueError(f"Job {job_id} not found")

        job_title = job_doc.get("title", "Untitled Job")
        lang_output = job_doc.get("outputs", {}).get(language)
        if not lang_output:
            raise ValueError(f"No outputs found for language {language}")

        edit_state = lang_output.get("accessible_video_edit_state")
        if not edit_state:
            raise ValueError(f"No edit state found for language {language}")

        # Use TMPDIR env var if set
        temp_base = os.environ.get('TMPDIR', None)
        with tempfile.TemporaryDirectory(dir=temp_base) as temp_dir:
            # 1. Download source video
            source_blob_path = _gcs_uri_to_blob_path(job_doc["source"]["gcs_uri"])
            source_video_path = os.path.join(temp_dir, "source.mp4")
            logger.info(f"Downloading source video from {source_blob_path}")
            gcs_service.bucket.blob(source_blob_path).download_to_filename(source_video_path)

            # 2. Regenerate TTS for queued cues (if any).
            # De-duplicate while preserving order so a cue queued twice is only synthesized once.
            cue_indices_to_regen = list(dict.fromkeys(regenerate_cue_indices or []))
            if cue_indices_to_regen:
                logger.info(f"Regenerating TTS for cues: {cue_indices_to_regen}")
                regen_results = await _regenerate_tts_cues(
                    job_id, language, cue_indices_to_regen, job_doc, db, temp_dir
                )

                # Update manifest with new GCS URIs for regenerated cues
                if regen_results:
                    job_doc_after_regen = await db.jobs.find_one({"_id": job_id}) or {}
                    current_manifest = (
                        job_doc_after_regen.get("outputs", {})
                        .get(language, {})
                        .get("ad_cue_manifest") or []
                    )
                    updated_manifest = _merge_regen_results_into_manifest(current_manifest, regen_results)
                    await db.jobs.update_one(
                        {"_id": job_id},
                        {"$set": {f"outputs.{language}.ad_cue_manifest": updated_manifest}}
                    )
                    logger.info(f"Updated ad_cue_manifest with {len(regen_results)} regenerated cues")

                # Clear regeneration queue after successful synthesis
                await db.jobs.update_one(
                    {"_id": job_id},
                    {
                        "$set": {
                            f"outputs.{language}.accessible_video_edit_state.tts_regeneration_queue": [],
                            "updated_at": datetime.utcnow()
                        }
                    }
                )

            # 3. Download AD VTT and per-cue MP3s
            ad_vtt_gcs = lang_output.get("ad_vtt_gcs")
            if not ad_vtt_gcs:
                raise ValueError(f"No AD VTT found for language {language}")
            ad_vtt_content = gcs_service.bucket.blob(_gcs_uri_to_blob_path(ad_vtt_gcs)).download_as_text()

            ad_cues_prefix = lang_output.get("ad_cues_gcs_prefix")
            if not ad_cues_prefix:
                raise ValueError(f"No AD cue segments found for language {language}")

            # Re-fetch job doc to get updated manifest after TTS regen
            job_doc = await db.jobs.find_one({"_id": job_id})
            if not job_doc:
                raise ValueError(f"Job {job_id} not found")
            lang_output = job_doc.get("outputs", {}).get(language) or {}

            ad_segments, cue_durations = _download_ad_cue_segments(
                job_id, language, lang_output, ad_cues_prefix, temp_dir
            )
            logger.info(f"Downloaded {len(ad_segments)} AD cue segments")

            # Validate that every VTT cue has an MP3. Equal counts are not enough:
            # indices [0, 1, 3] for three cues also leave cue 2 without audio.
            vtt_cues = VTTParser.parse(ad_vtt_content)
            downloaded_indices = {idx for idx, _ in ad_segments}
            missing_indices = set(range(len(vtt_cues))) - downloaded_indices
            if len(vtt_cues) != len(ad_segments) or missing_indices:
                logger.warning(
                    f"VTT cue count ({len(vtt_cues)}) does not match MP3 count ({len(ad_segments)}). "
                    f"Missing MP3s for cue indices: {sorted(missing_indices)}. "
                    f"This may happen when a new AD cue is inserted but TTS wasn't regenerated for shifted cues."
                )

            # 4. Build placements with adjusted pause points
            method = lang_output.get("accessible_video_method", "pause_insert")
            pause_points = edit_state.get("pause_points") or []
            if not pause_points:
                logger.info("No pause points in edit state — using VTT cue start times for all placements")

            placements = _build_placements_with_adjustments(
                ad_vtt_content, cue_durations, pause_points
            )
            logger.info(f"Built {len(placements)} placements with adjusted pause points")

            analysis = {
                "method": method,
                "method_rationale": "QC re-render with user adjustments",
                "placements": placements,
                "total_added_duration": sum(cue_durations) if method == "pause_insert" else 0,
                "warnings": []
            }

            # 5. Optionally run Whisper refinement
            if whisper_refine and method == "pause_insert":
                logger.info("Running Whisper pause point refinement...")
                analysis, whisper_warnings = await _refine_pause_points_for_rerender(
                    job_id, source_video_path, analysis, db, temp_dir
                )
                if whisper_warnings:
                    analysis["warnings"] = analysis.get("warnings", []) + whisper_warnings
                logger.info(f"Whisper refinement complete with {len(whisper_warnings)} warnings")

            # 6. Render accessible video (persist segments again for future edits)
            output_video_path = os.path.join(temp_dir, "accessible_video.mp4")
            gcs_segment_prefix = f"{job_id}/{language}/segments/"

            logger.info(f"Re-rendering accessible video using {method} method...")
            _rendered_path, updated_placements, segment_metadata, new_pause_points = (
                await video_renderer_service.render_accessible_video(
                    source_video_path,
                    ad_segments,
                    analysis,
                    output_video_path,
                    persist_segments=True,
                    gcs_segment_prefix=gcs_segment_prefix
                )
            )
            if updated_placements:
                analysis["placements"] = updated_placements

            # 7. Upload rendered video
            video_blob_path = f"{job_id}/{language}/accessible_video.mp4"
            video_blob = gcs_service.bucket.blob(video_blob_path)
            video_blob.content_type = "video/mp4"
            video_blob.upload_from_filename(output_video_path)
            video_gcs_uri = f"gs://{settings.gcs_bucket}/{video_blob_path}"
            logger.info(f"Uploaded re-rendered accessible video to {video_gcs_uri}")

            # 8. Generate re-timed captions if pause-insert
            retimed_captions_gcs_uri = None
            if method == "pause_insert":
                captions_vtt_gcs = lang_output.get("captions_vtt_gcs")
                if captions_vtt_gcs:
                    captions_blob = gcs_service.bucket.blob(_gcs_uri_to_blob_path(captions_vtt_gcs))
                    original_captions_vtt = captions_blob.download_as_text()

                    retimed_captions = vtt_retimer_service.retime_for_pause_insert(
                        original_captions_vtt, analysis
                    )

                    retimed_blob_path = f"{job_id}/{language}/accessible_captions.vtt"
                    retimed_blob = gcs_service.bucket.blob(retimed_blob_path)
                    retimed_blob.content_type = "text/vtt"
                    retimed_blob.upload_from_string(retimed_captions, content_type="text/vtt")
                    retimed_captions_gcs_uri = f"gs://{settings.gcs_bucket}/{retimed_blob_path}"
                    logger.info(f"Uploaded re-timed captions to {retimed_captions_gcs_uri}")

            # 9. Build new edit state
            new_edit_state = None
            if segment_metadata and new_pause_points:
                new_edit_state = AccessibleVideoEditState(
                    pause_points=new_pause_points,
                    video_segments=segment_metadata,
                    tts_regeneration_queue=[],
                    last_render_at=datetime.utcnow(),
                    whisper_refine_enabled=whisper_refine
                )

            # 10. Update job document
            update_fields = {
                f"outputs.{language}.accessible_video_gcs": video_gcs_uri,
                f"outputs.{language}.video_segments_gcs_prefix": f"gs://{settings.gcs_bucket}/{gcs_segment_prefix}",
                "status": JobStatus.PENDING_QC.value,
                "updated_at": datetime.utcnow()
            }
            if retimed_captions_gcs_uri:
                update_fields[f"outputs.{language}.retimed_captions_vtt_gcs"] = retimed_captions_gcs_uri
            if new_edit_state:
                update_fields[f"outputs.{language}.accessible_video_edit_state"] = new_edit_state.model_dump()

            completion_result = await db.jobs.update_one(
                {"_id": job_id, "status": JobStatus.RENDERING_QC.value},  # Only complete if still rendering
                {
                    "$set": update_fields,
                    "$push": {
                        "review.history": {
                            "at": datetime.utcnow(),
                            "status": JobStatus.PENDING_QC.value,
                            "by": "system",
                            "notes": f"Re-render complete for {language}"
                        }
                    }
                }
            )
            if completion_result.matched_count == 0:
                # Do not broadcast "complete": the job was moved to another status
                # while rendering, and announcing PENDING_QC would mislead the UI.
                logger.warning(
                    f"Re-render completion update skipped for job {job_id}/{language} — "
                    "job status changed during render (may have been cancelled or overridden)"
                )
                return

            # Broadcast completion
            broadcast_status_update(
                job_id,
                JobStatus.PENDING_QC.value,
                job_title=job_title,
                message=f"Accessible video re-render complete for {language.upper()}"
            )

            logger.info(f"Accessible video re-render complete for job {job_id}/{language}")

    finally:
        client.close()


async def _wait_for_tts_cue(cue_idx: int, task_result, timeout_s: float = _TTS_CUE_WAIT_TIMEOUT_S) -> dict:
    """Wait for a dispatched ``synthesize_cue_task`` and return its result dict.

    The wait polls ``ready()`` with ``asyncio.sleep`` so the event loop is not
    blocked, then joins the finished result.

    Raises:
        TimeoutError: if the task is not ready within *timeout_s* seconds of
            starting to wait on it.
    """
    deadline = time.monotonic() + timeout_s
    poll_count = 0
    while not task_result.ready():
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Timed out after {timeout_s:.0f}s waiting for TTS cue {cue_idx}")
        await asyncio.sleep(_TTS_CUE_POLL_INTERVAL_S)
        poll_count += 1
        if poll_count % 30 == 0:
            logger.info(f"Still waiting for TTS cue {cue_idx}...")

    # Joining a subtask result from inside a task requires allow_join_result.
    with allow_join_result():
        return task_result.get(timeout=120)


async def _regenerate_tts_cues(
    job_id: str, language: str, cue_indices: list[int], job_doc: dict, db, temp_dir: str
) -> list[dict]:
    """
    Regenerate TTS for specific cues using current VTT text.

    All valid cues are dispatched to the ``tts`` queue up front so they can
    synthesize concurrently. Results are then collected in *cue_indices* order.
    Indices outside ``[0, len(cues))`` are skipped with a warning.

    Returns:
        List of synthesis result dicts [{cue_index, gcs_uri, text, duration, success, ...}]
        so the caller can update the ad_cue_manifest.

    Raises:
        ValueError: if the language has no AD VTT.
        RuntimeError: if any cue reports an unsuccessful synthesis.
        TimeoutError: if a cue task does not finish in time.
    """
    logger.info(f"Regenerating TTS for {len(cue_indices)} cues")

    # Get AD VTT content
    lang_output = job_doc.get("outputs", {}).get(language) or {}
    ad_vtt_gcs = lang_output.get("ad_vtt_gcs")
    if not ad_vtt_gcs:
        raise ValueError(f"No AD VTT found for language {language}")
    ad_vtt_content = gcs_service.bucket.blob(_gcs_uri_to_blob_path(ad_vtt_gcs)).download_as_text()

    # Parse cues
    cues = parse_ad_cues(ad_vtt_content)

    # Get TTS preferences (tolerate missing/None sub-documents)
    tts_preferences = (job_doc.get("requested_outputs") or {}).get("tts_preferences") or {}
    voices_per_language = tts_preferences.get("voices_per_language") or {}
    voice_name = voices_per_language.get(language, tts_preferences.get("default_voice"))
    provider = tts_preferences.get("provider", "gemini")
    model = tts_preferences.get("model", "flash")
    speed = tts_preferences.get("speed", 1.0)
    style_preset = tts_preferences.get("style_preset", "neutral")
    custom_style_prompt = tts_preferences.get("custom_style_prompt")
    if style_preset == "custom" and custom_style_prompt:
        style_prompt = custom_style_prompt
    else:
        style_prompt = settings.gemini_tts_style_prompts.get(style_preset, "")

    # Dispatch every valid cue first so synthesis can run in parallel on the TTS workers.
    pending = []
    for cue_idx in cue_indices:
        # Reject negatives too: cues[-1] would silently synthesize the wrong cue.
        if not 0 <= cue_idx < len(cues):
            logger.warning(f"Cue index {cue_idx} out of range, skipping")
            continue

        cue = cues[cue_idx]
        logger.info(f"Synthesizing TTS for cue {cue_idx}: '{cue['text'][:50]}...'")

        task_result = synthesize_cue_task.apply_async(
            kwargs={
                "job_id": job_id,
                "language": language,
                "cue_index": cue_idx,
                "text": cue["text"],
                "start_time": cue["start_time"],
                "end_time": cue["end_time"],
                "voice_name": voice_name,
                "provider": provider,
                "model": model,
                "speed": speed,
                "style_prompt": style_prompt,
                "user_id": job_doc.get("client_id", "system"),
                "cost_project_id": job_doc.get("cost_tracker_project_id"),
            },
            queue="tts"
        )
        pending.append((cue_idx, task_result))

    regen_results = []
    for cue_idx, task_result in pending:
        result = await _wait_for_tts_cue(cue_idx, task_result)
        if not result or not result.get("success"):
            error_message = (result or {}).get("error_message")
            raise RuntimeError(f"TTS synthesis failed for cue {cue_idx}: {error_message}")
        regen_results.append(result)
        logger.info(f"TTS synthesis complete for cue {cue_idx}")

    logger.info(f"{len(regen_results)} of {len(cue_indices)} TTS cues regenerated")
    return regen_results


def _build_placements_with_adjustments(
    ad_vtt_content: str, cue_durations: list[float], pause_points: list[dict]
) -> list[dict]:
    """
    Build placement instructions using adjusted pause points from QC edits.

    Uses source_ms (source video coordinates) for pause point calculations,
    applying user adjustments as relative offsets.

    Placements are produced for the first ``min(len(cues), len(cue_durations))``
    VTT cues. ``cue_durations`` is used positionally. Pause points are then
    clamped so they never decrease in cue order.

    Args:
        ad_vtt_content: AD VTT content
        cue_durations: TTS durations per cue, in seconds
        pause_points: Pause point dicts with cue_index, source_ms, original_ms and adjusted_ms

    Returns:
        List of placement dicts
    """
    cues = VTTParser.parse(ad_vtt_content)

    # Build lookup of pause points (in seconds) by cue index using SOURCE coordinates
    adjusted_pause_by_cue: dict[int, float] = {}
    for pp in pause_points or []:
        cue_idx = pp.get("cue_index")
        if cue_idx is None:
            logger.warning(f"Skipping pause point without cue_index: {pp}")
            continue
        source_ms = pp.get("source_ms")
        original_ms = pp.get("original_ms")
        adjusted_ms = pp.get("adjusted_ms")

        # Fallback for data without source_ms (backward compatibility).
        # Without source_ms we cannot reliably map rendered adjustments back
        # to source coordinates, so we skip the delta and use original_ms as-is.
        if source_ms is None:
            logger.warning(
                f"Cue {cue_idx}: No source_ms found, falling back to original_ms "
                "without applying adjustment delta. "
                "Job may need to be re-processed from initial render for timing adjustments to work."
            )
            if original_ms is not None:
                pause_time_s = original_ms / 1000.0
            elif 0 <= cue_idx < len(cues):
                pause_time_s = cues[cue_idx].start_time
            else:
                pause_time_s = 0.0
            adjusted_pause_by_cue[cue_idx] = pause_time_s
            continue

        # Apply user adjustment as relative offset in source coordinates
        if adjusted_ms is not None and original_ms is not None:
            # User adjusted in rendered timeline - apply same delta to source
            adjustment_delta = adjusted_ms - original_ms
            adjusted_source_ms = source_ms + adjustment_delta
            logger.info(
                f"Cue {cue_idx}: Applying adjustment delta {adjustment_delta:.1f}ms "
                f"(rendered: {original_ms:.1f} -> {adjusted_ms:.1f}, "
                f"source: {source_ms:.1f} -> {adjusted_source_ms:.1f})"
            )
            source_ms = adjusted_source_ms

        # A large negative nudge near the start must not yield a negative source time.
        adjusted_pause_by_cue[cue_idx] = max(0.0, source_ms / 1000.0)

    placements = []
    for i, (cue, ad_duration) in enumerate(zip(cues, cue_durations)):
        # Use the source-based pause point if available, otherwise the VTT start time
        pause_point = adjusted_pause_by_cue.get(i, cue.start_time)
        placements.append({
            "ad_cue_index": i,
            "original_start_time": cue.start_time,
            "original_end_time": cue.end_time,
            "target_start_time": cue.start_time,
            "ad_duration": ad_duration,
            "pause_point": pause_point,
            "resume_from": pause_point,
            "pause_point_rationale": "User-adjusted during QC" if i in adjusted_pause_by_cue else "Original from VTT"
        })

    # Enforce pause_point monotonicity - pause_points must be non-decreasing in cue order.
    # User-adjusted pause points can cross over each other; clamp to maintain cue order.
    # The dicts are mutated in place, so each comparison sees the already-clamped predecessor.
    for prev, curr in zip(placements, placements[1:]):
        prev_pp = prev.get("pause_point")
        curr_pp = curr.get("pause_point")
        if curr_pp is not None and prev_pp is not None and curr_pp < prev_pp:
            logger.warning(
                f"Rerender monotonicity fix: cue {curr.get('ad_cue_index')} "
                f"pause_point {curr_pp:.2f}s < cue {prev.get('ad_cue_index')} "
                f"pause_point {prev_pp:.2f}s, clamping to {prev_pp:.2f}s"
            )
            curr["pause_point"] = prev_pp
            curr["resume_from"] = prev_pp

    return placements


async def _refine_pause_points_for_rerender(
    job_id: str, video_path: str, analysis: dict, db, temp_dir: str
) -> tuple[dict, list[str]]:
    """Run Whisper pause point refinement for re-render.

    Extracts audio from *video_path*, transcribes it, and refines the analysis
    placements around detected speech gaps. If transcription fails or finds no
    speech, the input *analysis* is returned unchanged with an explanatory
    warning. *db* is accepted for interface compatibility and is not used here.

    Returns:
        (analysis, warnings): a shallow copy of *analysis* with refined
        ``placements`` and ``whisper_refined=True``, plus the refinement warnings.
    """
    logger.info(f"Refining pause points with Whisper for re-render of job {job_id}")

    audio_path = os.path.join(temp_dir, "source_audio.mp3")
    await _extract_audio_for_whisper(video_path, audio_path)

    try:
        words = await _dispatch_whisper_transcription(job_id, audio_path)
    except Exception as e:
        logger.error(f"Whisper transcription failed: {e}")
        return analysis, [f"Whisper failed: {e} - using current timestamps"]

    if not words:
        return analysis, ["No speech detected - using current timestamps"]

    gaps = whisper_service.identify_speech_gaps(words)
    refined_placements, warnings = whisper_service.refine_all_pause_points(
        analysis.get("placements", []), words, gaps
    )

    refined_analysis = analysis.copy()
    refined_analysis["placements"] = refined_placements
    refined_analysis["whisper_refined"] = True
    return refined_analysis, warnings