video-accessibility/backend/app/tasks/whisper_transcribe.py
michael 7d2366d0f4 fix: add authentication for Cloud Run service calls
Cloud Run services are deployed with --no-allow-unauthenticated,
requiring an ID token in the Authorization header.

- Add _get_cloud_run_id_token() helper using google-auth library
- Update whisper_transcribe.py to include Bearer token in Cloud Run calls
- Update video_renderer.py to include Bearer token in FFmpeg Cloud Run calls

The ID token is fetched using the service account credentials
(GOOGLE_APPLICATION_CREDENTIALS) and targets the Cloud Run service URL.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-02 11:41:07 -06:00

197 lines
6.1 KiB
Python

"""Celery task for Whisper transcription: offloaded to a Cloud Run service when WHISPER_SERVICE_URL is configured, otherwise run with the local Whisper service."""
import os
import uuid
import google.auth.transport.requests
import httpx
from google.auth import default
from google.cloud import storage
from google.oauth2 import id_token
from ..core.config import settings
from ..core.logging import get_logger
from ..services.whisper_service import whisper_service
from . import celery_app
# Module-level logger from the project's logging helper, named after this module.
logger = get_logger(__name__)
def _get_cloud_run_id_token(audience: str) -> str:
    """
    Get an ID token for authenticating to Cloud Run services.

    ``google.oauth2.id_token.fetch_id_token`` performs its own credential
    discovery (the service-account key file named by
    GOOGLE_APPLICATION_CREDENTIALS, or the GCP metadata server), so no
    separate ``google.auth.default()`` lookup is needed. The previous version
    made that lookup and discarded the result.

    Args:
        audience: Target audience of the token -- the Cloud Run service URL
            being called.

    Returns:
        An ID token string, to be sent as ``Authorization: Bearer <token>``.

    Raises:
        google.auth.exceptions.DefaultCredentialsError: If no credentials
            able to mint an ID token can be found.
    """
    # Transport used by google-auth to perform the token request.
    request = google.auth.transport.requests.Request()
    return id_token.fetch_id_token(request, audience)
def _upload_audio_to_gcs_temp(audio_path: str, job_id: str) -> str:
    """
    Copy a local audio file into a temporary object in the configured bucket.

    The object is named ``temp/whisper/<job_id>/<random hex>/<basename>``.
    The random component is a fresh uuid4 per call, so repeated uploads for
    the same job land at distinct paths.

    Args:
        audio_path: Local filesystem path of the audio file to upload.
        job_id: Job identifier, used as a path component.

    Returns:
        The ``gs://`` URI of the uploaded object.
    """
    unique_dir = uuid.uuid4().hex
    object_name = f"temp/whisper/{job_id}/{unique_dir}/{os.path.basename(audio_path)}"

    gcs_client = storage.Client(project=settings.gcp_project_id)
    target = gcs_client.bucket(settings.gcs_bucket).blob(object_name)
    target.upload_from_filename(audio_path)

    uri = f"gs://{settings.gcs_bucket}/{object_name}"
    logger.info(f"Uploaded audio to temp GCS: {uri}")
    return uri
def _delete_gcs_temp(gcs_uri: str) -> None:
"""Delete temporary GCS file."""
try:
if not gcs_uri.startswith("gs://"):
return
parts = gcs_uri[5:].split("/", 1)
if len(parts) != 2:
return
bucket_name, blob_path = parts
client = storage.Client(project=settings.gcp_project_id)
bucket = client.bucket(bucket_name)
blob = bucket.blob(blob_path)
blob.delete()
logger.info(f"Deleted temp GCS file: {gcs_uri}")
except Exception as e:
logger.warning(f"Failed to delete temp GCS file {gcs_uri}: {e}")
def _transcribe_via_cloud_run(job_id: str, audio_path: str) -> dict:
    """
    Transcribe audio via the Cloud Run Whisper service.

    The local audio file is uploaded to a temporary GCS object. The service's
    ``/transcribe`` endpoint is called with that object's URI, and the
    temporary object is deleted afterwards, even if the call fails.

    Args:
        job_id: Job ID, used in the temp object path and for logging.
        audio_path: Local path of the audio file to transcribe.

    Returns:
        Dict with keys:
        - ``job_id``
        - ``word_count``: the service's value, or ``len(words)`` if absent
        - ``audio_duration``: ``end`` of the last word, or 0.0 with no words
        - ``words``: the word list exactly as returned by the service

    Raises:
        httpx.HTTPStatusError: If the service responds with an error status.
    """
    gcs_uri = None
    try:
        # The service reads its input from GCS, so stage the audio there first.
        gcs_uri = _upload_audio_to_gcs_temp(audio_path, job_id)

        service_url = settings.whisper_service_url.rstrip("/")
        endpoint = f"{service_url}/transcribe"
        logger.info(f"Calling Whisper Cloud Run service: {endpoint}")

        # The service is deployed with --no-allow-unauthenticated, so each
        # request needs an ID token whose audience is the service URL.
        # Named ``bearer_token`` so it does not shadow the module-level
        # ``google.oauth2.id_token`` import.
        bearer_token = _get_cloud_run_id_token(service_url)

        with httpx.Client(timeout=300.0) as client:
            response = client.post(
                endpoint,
                json={"gcs_uri": gcs_uri},
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {bearer_token}",
                },
            )
            response.raise_for_status()
            data = response.json()

        # The service does not report a duration; use the last word's end time.
        words = data.get("words", [])
        audio_duration = words[-1]["end"] if words else 0.0

        result = {
            "job_id": job_id,
            "word_count": data.get("word_count", len(words)),
            "audio_duration": audio_duration,
            "words": words,
        }
        logger.info(
            f"Cloud Run transcription complete for job {job_id}: "
            f"{result['word_count']} words, {audio_duration:.2f}s duration"
        )
        return result
    finally:
        # Always remove the temp object, whether the call succeeded or not.
        if gcs_uri:
            _delete_gcs_temp(gcs_uri)
def _transcribe_locally(job_id: str, audio_path: str) -> dict:
    """
    Transcribe audio in-process with the local Whisper service.

    Args:
        job_id: Job ID, echoed in the result and used for logging.
        audio_path: Local path of the audio file to transcribe.

    Returns:
        Dict with keys:
        - ``job_id``
        - ``word_count``: number of words returned
        - ``audio_duration``: ``end`` of the last word, or 0.0 with no words
        - ``words``: each word converted with its ``to_dict()`` method, so
          the result is serializable
    """
    transcript = whisper_service.transcribe_audio(audio_path)
    serialized = [word.to_dict() for word in transcript]
    duration = transcript[-1].end if transcript else 0.0
    count = len(transcript)

    logger.info(
        f"Local transcription complete for job {job_id}: "
        f"{count} words, {duration:.2f}s duration"
    )
    return {
        "job_id": job_id,
        "word_count": count,
        "audio_duration": duration,
        "words": serialized,
    }
@celery_app.task(bind=True, queue='whisper', time_limit=1800, soft_time_limit=1700)
def transcribe_video_audio_task(self, job_id: str, audio_path: str) -> dict:
    """
    Celery task (queue ``whisper``; soft limit 1700s, hard limit 1800s) that
    runs Whisper transcription for one job.

    When ``settings.whisper_service_url`` (WHISPER_SERVICE_URL) is set, the
    work is sent to the Cloud Run service. Otherwise it runs in-process with
    the local Whisper service. Worker concurrency is configured outside this
    file (presumably a dedicated worker).

    Args:
        job_id: Job ID for logging.
        audio_path: Path to the extracted audio file (MP3, WAV, etc.).

    Returns:
        Serializable dict:
        {
            "job_id": str,
            "word_count": int,
            "audio_duration": float,
            "words": [{"word": str, "start": float, "end": float}, ...]
        }

    Raises:
        Re-raises any failure after logging it. HTTP error responses from
        Cloud Run are logged with their status code and body.
    """
    logger.info(f"Starting Whisper transcription task for job {job_id}")
    try:
        if not settings.whisper_service_url:
            logger.info("Using local Whisper service")
            return _transcribe_locally(job_id, audio_path)

        logger.info(f"Using Cloud Run Whisper service: {settings.whisper_service_url}")
        return _transcribe_via_cloud_run(job_id, audio_path)
    except httpx.HTTPStatusError as err:
        logger.error(
            f"Cloud Run transcription failed for job {job_id}: "
            f"{err.response.status_code} - {err.response.text}"
        )
        raise
    except Exception as err:
        logger.error(f"Whisper transcription failed for job {job_id}: {err}")
        raise