|
import asyncio |
|
import json |
|
import mimetypes |
|
import os |
|
import re |
|
import uuid |
|
from typing import Tuple |
|
|
|
import aiohttp |
|
import gradio as gr |
|
from PIL import Image |
|
|
|
|
|
def get_ext(url):
    """Return the file extension of the resource a URL points at.

    Only the path component is inspected, so query strings (``?token=...``)
    and fragments are ignored, and a dot in the host name is never mistaken
    for an extension separator.

    Args:
        url: URL such as ``https://cdn.example.com/a/b.mp3?sig=1``.

    Returns:
        The extension without the leading dot (``"mp3"`` in the example
        above), or ``""`` when the last path segment contains no dot.
    """
    from urllib.parse import urlsplit

    # The previous regex required a "?" somewhere in the URL and raised
    # IndexError for plain URLs; parsing the path handles both forms.
    last_segment = urlsplit(url).path.rsplit("/", 1)[-1]
    _, dot, ext = last_segment.rpartition(".")
    return ext if dot else ""
|
|
|
|
|
async def download_file(url, local_filename):
    """Download ``url`` to ``<local_filename>.<ext>`` and return its absolute path.

    The extension is taken from ``get_ext(url)``. WebP images are re-encoded
    as JPEG (``<local_filename>.jpg``), and the downloaded ``.webp`` file is
    deleted afterwards.

    Args:
        url: HTTP(S) URL to fetch.
        local_filename: Destination path *without* an extension.

    Returns:
        Absolute path of the file that was written (the JPEG for WebP input).

    Raises:
        RuntimeError: if the server answers with a status other than 200.
    """
    ext = get_ext(url)
    # Avoid a dangling "name." when the URL path carries no extension.
    target = f"{local_filename}.{ext}" if ext else local_filename
    filename_with_ext = os.path.abspath(target)

    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                raise RuntimeError(f"{url} download failed")
            content = await response.read()

    with open(filename_with_ext, "wb") as f:
        f.write(content)

    if ext.lower() != "webp":
        return filename_with_ext

    # Return an absolute path here too, matching the non-WebP branch.
    jpg_path = os.path.abspath(f"{local_filename}.jpg")
    # The context manager releases the source handle before the .webp is
    # removed (os.remove on an open file fails on Windows).
    with Image.open(filename_with_ext) as im:
        im.convert("RGB").save(jpg_path, "jpeg")
    os.remove(filename_with_ext)
    return jpg_path
|
|
|
|
|
class HedraClient:
    """Minimal async client for the Hedra character-video API.

    Workflow: ``submit_task`` uploads an audio clip and a portrait image
    (both fetched from URLs first), creates a character job, and returns its
    id. ``get_response`` then polls that job until it completes or fails.
    """

    _DEFAULT_BASE_URL = "https://mercury.dev.dream-ai.com/api"
    # NOTE(review): an API key is committed to source. It is kept only as a
    # last-resort fallback so existing deployments keep working; rotate it
    # and supply HEDRA_API_KEY (or the ``api_key`` argument) instead.
    _LEGACY_API_KEY = "sk_hedra-TxkxBe8htuAuGXwoPYgjHhYpwcQ3gdFmcGdRTLksRKUcSQEpm7VCNzSNj2680fZC"

    def __init__(self, api_key=None, base_url=None):
        """Create a client and ensure the ``temp`` download directory exists.

        Args:
            api_key: Hedra API key. Defaults to the ``HEDRA_API_KEY``
                environment variable, then to the legacy built-in key.
            base_url: API root. Defaults to ``_DEFAULT_BASE_URL``.
        """
        self._base_url = (base_url or self._DEFAULT_BASE_URL).rstrip("/")
        # Plain string (not an f-string): "{task_id}" is filled in later via .format().
        self._check_task_url = self._base_url + "/v1/projects/{task_id}"
        self._key = api_key or os.environ.get("HEDRA_API_KEY") or self._LEGACY_API_KEY
        # Total per-request budget for JSON calls and the portrait upload.
        self.timeout = aiohttp.ClientTimeout(total=10)
        # download_file() stores media under ./temp before it is uploaded.
        os.makedirs("temp", exist_ok=True)

    def _headers(self):
        """Return the authentication headers sent with every request."""
        return {"X-API-KEY": self._key}

    async def _upload(self, endpoint, file_url, timeout=None):
        """Download ``file_url`` locally and POST it as multipart field ``file``.

        Args:
            endpoint: Path below the base URL, e.g. ``"/v1/audio"``.
            file_url: Remote URL of the media to upload.
            timeout: Optional ``aiohttp.ClientTimeout`` for the POST; ``None``
                keeps aiohttp's session default.

        Returns:
            The decoded JSON body of the upload response.

        Raises:
            RuntimeError: if the download fails or the API answers with an
                HTTP error status.
        """
        local_path = await download_file(file_url, f"temp/{uuid.uuid4()}")
        try:
            # A single handle, closed by ``with`` before the file is removed.
            # The previous code opened the file twice and closed neither.
            with open(local_path, "rb") as fh:
                form = aiohttp.FormData()
                form.add_field("file", fh, filename=os.path.basename(local_path))
                # Only pass ``timeout`` when given: an explicit None would
                # disable aiohttp's default timeout entirely.
                post_kwargs = {} if timeout is None else {"timeout": timeout}
                async with aiohttp.ClientSession(headers=self._headers()) as session:
                    async with session.post(
                        f"{self._base_url}{endpoint}", data=form, **post_kwargs
                    ) as resp:
                        if resp.status >= 400:
                            body = await resp.text()
                            raise RuntimeError(
                                f"Upload to {endpoint} failed with HTTP {resp.status}: {body}"
                            )
                        return await resp.json()
        finally:
            if os.path.exists(local_path):
                os.remove(local_path)

    async def post_audio(self, audio_url):
        """Upload the audio at ``audio_url`` to ``/v1/audio`` and return the JSON reply."""
        return await self._upload("/v1/audio", audio_url)

    async def post_image(self, image_url):
        """Upload the image at ``image_url`` to ``/v1/portrait`` and return the JSON reply."""
        return await self._upload("/v1/portrait", image_url, timeout=self.timeout)

    async def submit_task(self, audio_url: str, image_url: str, aspect_ratio: str) -> Tuple[str, str]:
        """Upload both media files concurrently and create a character job.

        Args:
            audio_url: URL of the driving audio clip.
            image_url: URL of the portrait image.
            aspect_ratio: Forwarded verbatim as ``aspectRatio``.

        Returns:
            ``(task_id, request_id)``: the response's ``jobId`` and its
            ``request_id`` (``None`` when the response omits it).

        Raises:
            RuntimeError: if an upload reply has no ``url`` or the job
                response has no ``jobId``.
        """
        audio_result, image_result = await asyncio.gather(
            self.post_audio(audio_url), self.post_image(image_url)
        )
        for kind, result in (("audio", audio_result), ("image", image_result)):
            if not isinstance(result, dict) or "url" not in result:
                raise RuntimeError(f"{kind} upload returned no url: {result}")

        payload = {
            "voiceUrl": audio_result["url"],
            "avatarImage": image_result["url"],
            "aspectRatio": aspect_ratio,
        }
        async with aiohttp.ClientSession(headers=self._headers(), timeout=self.timeout) as session:
            async with session.post(f"{self._base_url}/v1/characters", json=payload) as response:
                data = await response.json()

        task_id = data.get("jobId")
        # Explicit raise instead of assert: asserts vanish under ``python -O``.
        if task_id is None:
            raise RuntimeError(f"Failed to submit task, {data}")
        return task_id, data.get("request_id")

    async def get_response(self, task_id: str, poll_interval: float = 4, max_wait=None) -> Tuple[str, float]:
        """Poll job ``task_id`` until it reaches a terminal status.

        Args:
            task_id: Job id returned by ``submit_task``.
            poll_interval: Seconds to sleep between status checks.
            max_wait: Optional overall deadline in seconds. ``None`` (the
                default) polls indefinitely, as before.

        Returns:
            ``(video_url, video_duration)``. NOTE(review): the duration is a
            hard-coded placeholder of 4 and is not read from the API; confirm.

        Raises:
            RuntimeError: if the status is ``"Failed"`` or missing, or a
                completed job has no ``videoUrl``.
            TimeoutError: if ``max_wait`` elapses before the job finishes.
        """
        loop = asyncio.get_running_loop()
        deadline = None if max_wait is None else loop.time() + max_wait
        status_url = self._check_task_url.format(task_id=task_id)

        async with aiohttp.ClientSession(headers=self._headers(), timeout=self.timeout) as session:
            while True:
                async with session.get(status_url) as response:
                    data = await response.json()
                status = data.get("status")

                if status == "Completed":
                    video_url = data.get("videoUrl")
                    if video_url is None:
                        raise RuntimeError(f"Failed to get video_url from response[{data}]")
                    video_duration = 4
                    return video_url, video_duration

                # NOTE(review): only "Failed" is treated as terminal; if the
                # API also reports a cancel status it would keep polling.
                if status is None or status == "Failed":
                    # ``or {}`` also covers an explicit ``"output": null``.
                    message = (data.get("output") or {}).get("message", "")
                    raise RuntimeError(f"Task {task_id} failed or was canceled. {message}")

                if deadline is not None and loop.time() + poll_interval > deadline:
                    raise TimeoutError(f"Task {task_id} did not finish within {max_wait} seconds")
                await asyncio.sleep(poll_interval)
|
|
|
|
|
def gradio_interface(audio_url, image_url, aspect_ratio):
    """Gradio callback: run one Hedra job end to end and return its result.

    Args:
        audio_url: URL of the driving audio clip.
        image_url: URL of the portrait image.
        aspect_ratio: Aspect ratio string forwarded verbatim to the API.

    Returns:
        ``(video_url, video_duration)`` as returned by
        ``HedraClient.get_response``.
    """
    client = HedraClient()

    async def process():
        task_id, _request_id = await client.submit_task(audio_url, image_url, aspect_ratio)
        return await client.get_response(task_id)

    # asyncio.run() creates *and closes* a fresh loop. The previous
    # new_event_loop()/run_until_complete() pair never closed its loop,
    # leaking one loop (and its selector fds) per request.
    video_url, video_duration = asyncio.run(process())
    return video_url, video_duration
|
|
|
|
|
# Gradio UI: three free-text inputs mapped positionally onto
# gradio_interface(audio_url, image_url, aspect_ratio), and two text outputs
# for the returned (video_url, video_duration) tuple.
# gr.Textbox replaces gr.inputs.Textbox / gr.outputs.Textbox, which were
# deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Audio URL"),
        gr.Textbox(label="Image URL"),
        gr.Textbox(label="Aspect Ratio"),
    ],
    outputs=[
        gr.Textbox(label="Video URL"),
        gr.Textbox(label="Video Duration"),
    ],
    title="Hedra Gradio Interface",
    description="Submit audio and image URLs to generate a video.",
)

# NOTE(review): share=True publishes a public link, letting anyone trigger
# jobs billed to the configured Hedra API key; confirm this is intended.
iface.launch(share=True)
|
|