110 lines
4.8 KiB
Python
110 lines
4.8 KiB
Python
import asyncio
|
|
import sys
|
|
import os
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
sys.path.append(os.getcwd())
|
|
|
|
from app.services import video_upscaler
|
|
from app.models.job import Job
|
|
from app.models.asset import Asset
|
|
from app.config import settings
|
|
|
|
# Stand-ins for the DB session, the upscale Job row and its input Asset, so
# the upscaler can run without a real database.
mock_db, mock_job, mock_asset = MagicMock(), MagicMock(), MagicMock()

mock_job.configure_mock(
    id="test_job_123",
    input_asset_ids=["asset_123"],
    input_data={
        "scale": 2,
        "model": "Proteus",  # friendly name; the check expects Topaz "prob-4"
        "fps": 60.0,
        "sharpening": 50,
        "recover_detail": 30,
        "face_enhancement": False,  # kept off so the model mapping is what gets tested
    },
)

mock_asset.configure_mock(
    id="asset_123",
    file_path="test_video.mp4",
    duration_seconds=10.0,
    width=1920,
    height=1080,
    file_size_bytes=10 * 1024 * 1024,
)

# db.query(...).filter(...).first() yields the job on the first lookup and the
# asset on the second; the list is consumed, so a third call raises StopIteration.
mock_db.configure_mock(
    **{"query.return_value.filter.return_value.first.side_effect": [mock_job, mock_asset]}
)
|
|
|
|
def _post_url(call):
    """Return the URL a recorded ``client.post`` call targeted, or None.

    Handles both ``post(url, ...)`` and ``post(url=..., ...)`` call styles.
    """
    if call.args:
        return call.args[0]
    return call.kwargs.get("url")


def _check_create_payload(payload):
    """Print the create-request payload and check the mapped parameters.

    Expected mapping (from the job's input_data): model "Proteus" -> "prob-4",
    sharpening 50 -> "details", recover_detail 30 ->
    "recoverOriginalDetailValue", fps 60.0 -> output["frameRate"].

    Returns a list of failure messages (empty when every check passed).
    """
    failures = []
    if not isinstance(payload, dict):
        payload = {}
    filters = payload.get("filters") or []
    output = payload.get("output") or {}
    frame_rate = output.get("frameRate") if isinstance(output, dict) else None

    print(f" Filters Sent: {filters}")
    print(f" Output FrameRate: {frame_rate}")

    expected_filter = {"details": 50, "recoverOriginalDetailValue": 30, "model": "prob-4"}
    first = filters[0] if isinstance(filters, list) and filters else None
    # An empty/missing filter list is a failure too (it used to pass silently).
    if isinstance(first, dict) and all(first.get(k) == v for k, v in expected_filter.items()):
        print(" [VERIFIED] Parameters (details, recoverOriginalDetailValue, Model=prob-4) correctly mapped from Proteus!")
    else:
        msg = f"Parameter mapping incorrect. Got: {first}"
        print(f" [FAILED] {msg}")
        failures.append(msg)

    if frame_rate == 60.0:
        print(" [VERIFIED] FPS correctly mapped!")
    else:
        msg = f"FPS incorrect. Got: {frame_rate}"
        print(f" [FAILED] {msg}")
        failures.append(msg)

    return failures


async def test_topaz_payload():
    """Check the request ``video_upscaler.upscale`` sends to the Topaz video API.

    ``SessionLocal``, ``httpx.AsyncClient`` and ``extract_video_metadata`` are
    patched, so the upscaler's httpx requests are intercepted by mocks.
    ``upscale`` is expected to stop partway (the input file and later
    endpoints are fake); only its POST to the create-request endpoint is
    inspected. Progress is printed as [SUCCESS]/[VERIFIED]/[FAILED] lines.

    Returns early (skip) when ``settings.topaz_api_key`` is unset.

    Raises:
        AssertionError: if any check failed, so running the script (or a test
            runner) reports the failure instead of only printing it.
    """
    # AsyncMock: httpx.AsyncClient request methods are coroutines and must be
    # awaitable. Imported locally so the block needs no extra file imports.
    from unittest.mock import AsyncMock, MagicMock

    expected_url = "https://api.topazlabs.com/video/"
    # Upper bound on the upscale run: with awaitable mocks it can get further
    # than before (possibly into status polling) and must not hang the check.
    run_timeout_s = 30.0
    fake_metadata = {"duration_seconds": 10, "fps": 30, "width": 1920, "height": 1080}

    print("Testing Topaz Video Upscaling Payload Construction...")

    # Check if Topaz API Key is set
    if not settings.topaz_api_key:
        print(" [SKIP] Topaz API Key not set (skipping actual API call)")
        return

    # side_effect lists are consumed by each run; re-arm so the check also
    # works when invoked more than once in the same process.
    mock_db.query.return_value.filter.return_value.first.side_effect = [mock_job, mock_asset]

    create_response = MagicMock(status_code=200)
    create_response.json.return_value = {"requestId": "mock_req_id"}
    upload_response = MagicMock(status_code=200)
    upload_response.json.return_value = {"urls": []}

    with patch("app.services.video_upscaler.SessionLocal", return_value=mock_db), \
            patch("app.services.video_upscaler.httpx.AsyncClient") as mock_client_cls:
        # The client object yielded by `async with httpx.AsyncClient() as client`.
        client = mock_client_cls.return_value.__aenter__.return_value
        client.post = AsyncMock(return_value=create_response)
        client.patch = AsyncMock(return_value=upload_response)

        # Patch the helper where it is defined AND in the upscaler's namespace:
        # a `from app.utils.video import extract_video_metadata` binding is not
        # affected by patching app.utils.video alone. create=True makes the
        # second patch harmless if the upscaler does not import the name.
        with patch("app.utils.video.extract_video_metadata", return_value=fake_metadata), \
                patch("app.services.video_upscaler.extract_video_metadata",
                      create=True, return_value=fake_metadata):
            try:
                await asyncio.wait_for(
                    video_upscaler.upscale("test_job_123"), timeout=run_timeout_s
                )
            except Exception as e:
                # Expected: the fake job cannot complete. Only the requests
                # already sent matter for this check.
                print(f" [INFO] Execution stopped as expected: {e}")

    failures = []
    post_calls = client.post.call_args_list
    if not post_calls:
        print(" [FAILED] POST not called")
        failures.append("POST not called")
    else:
        # Further POSTs may follow the create call now that responses are
        # awaitable, and call_args only holds the last one, so search the
        # full call list for the create-request endpoint.
        create_call = next((c for c in post_calls if _post_url(c) == expected_url), None)
        if create_call is None:
            urls = [_post_url(c) for c in post_calls]
            print(f" [FAILED] different URL called: {urls}")
            failures.append(f"different URL called: {urls}")
        else:
            print(" [SUCCESS] API Endpoint Correct")
            failures.extend(_check_create_payload(create_call.kwargs.get("json")))

    if failures:
        raise AssertionError("Topaz payload check failed: " + "; ".join(failures))
|
if __name__ == "__main__":
    # Running as a script: drive the async check to completion on a new event loop.
    pending_check = test_topaz_payload()
    asyncio.run(pending_check)
|