forge/backend/app/services/video_upscaler.py

371 lines
16 KiB
Python

"""Video Upscaler Service - Topaz Labs API"""
import httpx
import os
from uuid import uuid4
from datetime import datetime
import asyncio
from app.database import SessionLocal
from app.models.job import Job
from app.models.asset import Asset
from app.config import settings
import logging
logger = logging.getLogger(__name__)

# Topaz Video AI model codes, keyed by the display name that arrives as the
# job's ``model`` input. Names missing from this table are resolved to
# Proteus ("prob-4") by the lookup in ``upscale``.
VIDEO_MODELS = {
    # Proteus
    "Proteus": "prob-4",
    # Artemis family
    "Artemis High Quality": "ahq-12",
    "Artemis Medium Quality": "amq-13",
    "Artemis Low Quality": "alq-13",
    # Gaia family
    "Gaia High Quality": "ghq-5",
    "Gaia CG": "gcg-5",
    # Theia family
    "Theia Detail": "thd-3",
    "Theia Fidelity": "thf-4",
    # Nyx family
    "Nyx": "nyx-3",
    "Nyx Fast": "nxf-1",
    # Dione family
    "Dione DV": "ddv-3",
    "Dione TV": "dtv-4",
    # Iris: "iris-2" replaces the older "iris-1" code.
    "Iris": "iris-2",
    # Fallback/default entry, mapped to Proteus.
    "Auto": "prob-4",
}
async def upscale(job_id: str):
"""Upscale video using Topaz Labs API"""
db = SessionLocal()
try:
job = db.query(Job).filter(Job.id == job_id).first()
if not job:
return
input_data = job.input_data
input_asset_ids = job.input_asset_ids
if not input_asset_ids:
raise ValueError("No input asset provided")
input_asset = db.query(Asset).filter(Asset.id == input_asset_ids[0]).first()
if not input_asset:
raise ValueError("Input asset not found")
job.progress = 5
job.api_provider = "topaz"
job.api_model = input_data.get("model", "auto")
db.commit()
scale = input_data.get("scale", 2)
model = input_data.get("model", "auto")
frame_interpolation = input_data.get("frame_interpolation", 1)
# New parameters
fps = input_data.get("fps")
sharpening = input_data.get("sharpening")
recover_detail = input_data.get("recover_detail")
add_noise = input_data.get("add_noise")
video_type = input_data.get("video_type", "Progressive")
face_enhancement = input_data.get("face_enhancement", False)
# Get video metadata with ffprobe
from app.utils.video import extract_video_metadata
metadata = extract_video_metadata(input_asset.file_path)
# Use extracted metadata or fallback to asset record
duration = metadata.get('duration_seconds') or float(input_asset.duration_seconds or 10)
source_fps = metadata.get('fps') or 30
width = metadata.get('width') or input_asset.width or 1920
height = metadata.get('height') or input_asset.height or 1080
video_info = {
"container": "mp4",
"size": input_asset.file_size_bytes,
"duration": duration,
"frameCount": int(duration * source_fps),
"frameRate": source_fps,
"resolution": {
"width": width,
"height": height
}
}
output_width = video_info["resolution"]["width"] * scale
output_height = video_info["resolution"]["height"] * scale
video_type = input_data.get("video_type", "progressive")
face_enhancement = input_data.get("face_enhancement", False)
# Determine target FPS
target_fps = fps if fps else (video_info["frameRate"] * frame_interpolation)
job.progress = 10
db.commit()
async with httpx.AsyncClient(timeout=1800) as client:
# Build filters
filters = []
# 1. Enhancement filter
# Logic: If face_enhancement is True, strictly use 'iris-2' (Iris).
# Otherwise, lookup model in VIDEO_MODELS. If not found, default to 'prob-4' (Proteus).
# Build filters logic matching PHP structure exactly (from estimate.php)
selected_model_code = "prob-4"
if face_enhancement:
selected_model_code = "iris-2"
else:
selected_model_code = VIDEO_MODELS.get(model, "prob-4")
# PHP used "Progressive" (User's working code)
video_type_val = video_type.capitalize() if video_type else "Progressive"
enhance_filter = {
"model": selected_model_code,
"videoType": video_type_val,
"auto": "Auto",
"fieldOrder": "Auto", # Added from estimate.php
"focusFixLevel": "None", # Added from estimate.php
"blur": 0.0, # Default from estimate.php
"grain": 0.0, # Default from estimate.php
"grainSize": 1.5, # Default from estimate.php
"recoverOriginalDetailValue": 0.2 # Default from estimate.php
}
# Override defaults with inputs if present
if sharpening is not None:
# Note: estimate.php doesn't map 'sharpening' to 'details', it uses 'recoverOriginalDetailValue'
# But typically 'details' is the param. Let's stick to valid defaults from PHP first.
# Actually estimate.php uses $_POST['recoverDetail'] -> recoverOriginalDetailValue.
pass
if recover_detail is not None:
enhance_filter["recoverOriginalDetailValue"] = int(recover_detail) / 100.0 if int(recover_detail) > 1 else int(recover_detail)
# Map other UI sliders if we want, but let's stick to working PHP defaults + Model for now to FIX it.
filters.append(enhance_filter)
# Create video enhancement request - match PHP keys layout
payload = {
"source": video_info,
"filters": filters,
"output": {
"resolution": {
"width": output_width,
"height": output_height
},
"frameRate": target_fps,
"audioCodec": "AAC",
"audioTransfer": "Copy",
# Added missing fields from estimate.php
"videoEncoder": "H265",
"videoBitrate": "6000k",
"videoProfile": "Main",
"cropToFit": False,
"container": "mp4"
}
}
print(f"DEBUG: Topaz Video Payload: {payload}")
# Revert to root /video endpoint as /v1 returned 404
response = await client.post(
"https://api.topazlabs.com/video/",
headers={
"X-API-Key": settings.topaz_api_key,
"Content-Type": "application/json",
"Accept": "application/json"
},
json=payload
)
if response.status_code >= 400:
logger.error(f"Topaz Video API Error: {response.text}")
response.raise_for_status()
result = response.json()
logger.info(f"Topaz Video Creation Response: {result}")
request_id = result.get("requestId")
if not request_id:
raise ValueError(f"No requestId returned from Topaz: {result}")
job.progress = 15
job.api_request_id = request_id
db.commit()
# Accept the request and get upload URLs
accept_response = await client.patch(
f"https://api.topazlabs.com/video/{request_id}/accept",
headers={"X-API-Key": settings.topaz_api_key}
)
logger.info(f"Topaz Video Accept Response: {accept_response.text}")
accept_response.raise_for_status()
accept_data = accept_response.json()
upload_urls = accept_data.get("urls", [])
job.progress = 20
db.commit()
# Upload video file in parts
with open(input_asset.file_path, "rb") as f:
video_data = f.read()
part_size = len(video_data) // len(upload_urls) if upload_urls else len(video_data)
upload_results = []
for i, url in enumerate(upload_urls):
start = i * part_size
end = start + part_size if i < len(upload_urls) - 1 else len(video_data)
part_data = video_data[start:end]
upload_response = await client.put(
url,
content=part_data,
headers={"Content-Type": "application/octet-stream"}
)
etag = upload_response.headers.get("ETag", "").strip('"')
upload_results.append({
"partNum": i + 1,
"eTag": etag
})
job.progress = 20 + (i + 1) * (30 / len(upload_urls))
db.commit()
# Complete the upload
complete_response = await client.patch(
f"https://api.topazlabs.com/video/{request_id}/complete-upload",
headers={
"X-API-Key": settings.topaz_api_key,
"Content-Type": "application/json"
},
json={"uploadResults": upload_results}
)
logger.info(f"Topaz Video Complete Upload Response: {complete_response.text}")
complete_response.raise_for_status()
job.progress = 50
db.commit()
# Poll for completion
output_asset = None
# Poll for completion
output_asset = None
for i in range(900): # Wait up to 30 minutes (2s * 900)
await asyncio.sleep(2)
try:
# Generic logging for debugging
if i % 10 == 0:
logger.info(f"Polling Topaz Video Job {request_id} (Attempt {i})")
# Trying generic resource URL first (some APIs use GET /resource/{id} for status)
status_url = f"https://api.topazlabs.com/video/{request_id}"
status_response = await client.get(
status_url,
headers={"X-API-Key": settings.topaz_api_key}
)
if status_response.status_code == 404:
# Fallback to /status if generic 404s (just in case)
# But user reported /status 404s.
logger.warning(f"Topaz Status Check 404 at {status_url}. Job might be lost or URL wrong.")
# Don't break immediately, maybe ephemeral?
pass
elif status_response.status_code != 200:
logger.warning(f"Topaz Status Check returned {status_response.status_code}: {status_response.text}")
continue
status_data = status_response.json()
status = status_data.get("status", "").lower()
if i % 10 == 0:
logger.info(f"Topaz Video Status: {status} Data: {status_data}")
if status == "completed" or status_data.get("outputUrl") or status_data.get("url"):
output_url = status_data.get("outputUrl") or status_data.get("url")
if output_url:
logger.info(f"Topaz Video API Success. Output URL: {output_url}")
video_response = await client.get(output_url)
upscaled_data = video_response.content
# Save output
base_name = os.path.splitext(input_asset.original_filename)[0]
clean_base_name = base_name.replace(" ", "_")
clean_model = model.replace(" ", "_")
filename = f"{clean_base_name}_{scale}X_{clean_model}.mp4"
storage_path = os.path.join(settings.storage_path, "videos")
os.makedirs(storage_path, exist_ok=True)
file_path = os.path.join(storage_path, filename)
with open(file_path, "wb") as f:
f.write(upscaled_data)
# Generate thumbnail
thumbnail_path = None
try:
from app.utils.video import generate_video_thumbnail
thumb_filename = f"{os.path.splitext(filename)[0]}_thumb.jpg"
thumb_path = os.path.join(storage_path, thumb_filename)
if generate_video_thumbnail(file_path, thumb_path, timestamp=1.0):
thumbnail_path = thumb_path
logger.info(f"Generated thumbnail for upscaled video: {thumb_path}")
except Exception as e:
logger.warning(f"Failed to generate thumbnail: {e}")
# Create output asset
output_asset = Asset(
user_id=job.user_id,
project_id=job.project_id,
original_filename=filename,
stored_filename=filename,
file_path=file_path,
thumbnail_path=thumbnail_path,
file_type="video",
mime_type="video/mp4",
file_size_bytes=len(upscaled_data),
width=output_width,
height=output_height,
duration_seconds=input_asset.duration_seconds,
source_module="video_upscaler",
source_job_id=job.id,
parent_asset_id=input_asset.id,
asset_metadata={
"scale": scale,
"model": model,
"frame_interpolation": frame_interpolation
}
)
db.add(output_asset)
db.commit()
db.refresh(output_asset)
job.output_asset_ids = [output_asset.id]
job.output_data = {"asset_id": str(output_asset.id), "file_path": file_path}
break
elif status == "failed":
raise ValueError(f"Video enhancement failed: {status_data.get('error')}")
except Exception as e:
logger.warning(f"Error checking status for job {job.id}: {e}")
# Continue polling
job.progress = min(50 + (i * 0.05), 95) # Slower progress for longer wait
db.commit()
if not output_asset:
raise TimeoutError("Video upscaling timed out or failed to return output")
job.progress = 100
job.status = "completed"
job.completed_at = datetime.utcnow()
db.commit()
except Exception as e:
job.status = "failed"
job.error_message = str(e)
db.commit()
finally:
db.close()