# gbc-backup / content/hedra_gradio.py
# Commit author: sicer
# Initial commit from existing repo (e9fa53a)
import asyncio
import json
import mimetypes
import os
import re
import uuid
from typing import Tuple
import aiohttp
import gradio as gr
from PIL import Image
def get_ext(url):
    """Return the file extension (without the leading dot) of the resource at *url*.

    The extension is taken from the last segment of the URL *path*. Query
    strings and fragments are therefore ignored, and URLs without a query
    string are handled too (the previous regex required a ``?``).

    >>> get_ext("https://cdn.example.com/img/photo.webp?sig=abc")
    'webp'

    Raises:
        ValueError: if the last path segment has no file extension.
    """
    import posixpath
    from urllib.parse import urlsplit

    path = urlsplit(url).path
    # Only the final segment matters; dots in the host or in directory
    # names must not be mistaken for an extension.
    ext = posixpath.splitext(posixpath.basename(path))[1]
    if not ext:
        raise ValueError(f"Cannot determine file extension from URL: {url}")
    return ext[1:]
async def download_file(url, local_filename):
    """Download *url* to ``{local_filename}.{ext}`` and return the local path.

    The extension comes from ``get_ext(url)``. For non-WebP files the
    absolute path of the downloaded file is returned. WebP images are
    re-encoded as JPEG at ``{local_filename}.jpg``, that path is returned
    as given, and the intermediate ``.webp`` file is deleted.

    Raises:
        RuntimeError: if the server responds with a status other than 200.
    """
    # Resolve the extension first so a malformed URL fails before any request.
    ext = get_ext(url)
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                raise RuntimeError(f"{url} download failed (HTTP {response.status})")
            content = await response.read()

    filename_with_ext = os.path.abspath(f"{local_filename}.{ext}")
    with open(filename_with_ext, "wb") as f:
        f.write(content)

    if ext != "webp":
        return filename_with_ext

    jpg_path = f"{local_filename}.jpg"
    try:
        # The context manager closes the source file handle; otherwise the
        # os.remove below can fail on platforms that lock open files.
        with Image.open(filename_with_ext) as im:
            im.convert("RGB").save(jpg_path, "jpeg")
    finally:
        # Remove the intermediate WebP even if conversion fails.
        os.remove(filename_with_ext)
    return jpg_path
class HedraClient:
    """Minimal async client for the Hedra character-generation API.

    Workflow: upload an audio file and a portrait image, submit a character
    job that references the uploaded URLs, then poll the project until it
    reports ``Completed`` (or fails).
    """

    def __init__(self, api_key=None):
        """Create a client and ensure the local ``temp/`` directory exists.

        Args:
            api_key: Hedra API key. When omitted, the ``HEDRA_API_KEY``
                environment variable is used, then the legacy embedded key.
        """
        self._base_url = "https://mercury.dev.dream-ai.com/api"
        self._check_task_url = self._base_url + "/v1/projects/{task_id}"
        # NOTE(review): the embedded key is committed to source control and
        # should be rotated; it remains only as a fallback so existing
        # deployments keep working. Prefer HEDRA_API_KEY.
        self._key = (
            api_key
            or os.environ.get("HEDRA_API_KEY")
            or "sk_hedra-TxkxBe8htuAuGXwoPYgjHhYpwcQ3gdFmcGdRTLksRKUcSQEpm7VCNzSNj2680fZC"
        )
        self.timeout = aiohttp.ClientTimeout(total=10)
        os.makedirs("temp", exist_ok=True)

    @property
    def _headers(self):
        """Authentication headers sent with every Hedra request."""
        return {"X-API-KEY": self._key}

    async def _upload(self, source_url, endpoint, timeout=None):
        """Download *source_url* into ``temp/`` and POST it as multipart field ``file``.

        Args:
            source_url: Remote file to fetch via ``download_file``.
            endpoint: API path appended to the base URL, e.g. ``"/v1/audio"``.
            timeout: Optional ``aiohttp.ClientTimeout`` for the upload request;
                when ``None`` the session default is used.

        Returns:
            The decoded JSON response body.

        The temporary local copy is always deleted afterwards.
        """
        local_path = await download_file(source_url, f"temp/{uuid.uuid4()}")
        request_kwargs = {"headers": self._headers}
        if timeout is not None:
            request_kwargs["timeout"] = timeout
        try:
            async with aiohttp.ClientSession() as session:
                # Open the file exactly once, inside a context manager, so the
                # handle is closed before the finally block deletes the file.
                with open(local_path, "rb") as fh:
                    form = aiohttp.FormData()
                    form.add_field("file", fh)
                    async with session.post(
                        f"{self._base_url}{endpoint}", data=form, **request_kwargs
                    ) as resp:
                        return await resp.json()
        finally:
            if os.path.exists(local_path):
                os.remove(local_path)

    async def post_audio(self, audio_url):
        """Upload the audio at *audio_url* to ``/v1/audio``; return the JSON reply."""
        return await self._upload(audio_url, "/v1/audio")

    async def post_image(self, image_url):
        """Upload the image at *image_url* to ``/v1/portrait``; return the JSON reply."""
        return await self._upload(image_url, "/v1/portrait", timeout=self.timeout)

    @staticmethod
    def _uploaded_url(result, kind):
        """Extract ``url`` from an upload response, or raise with the response attached."""
        url = result.get("url") if isinstance(result, dict) else None
        if url is None:
            raise RuntimeError(f"Failed to upload {kind}, {result}")
        return url

    async def submit_task(self, audio_url: str, image_url: str, aspect_ratio: str) -> Tuple[str, str]:
        """Upload both inputs concurrently and submit a character-generation job.

        Returns:
            ``(task_id, request_id)``. ``request_id`` is ``None`` when the
            response does not contain it.

        Raises:
            RuntimeError: if an upload returns no ``url`` or the submit
                response has no ``jobId``.
        """
        audio_result, image_result = await asyncio.gather(
            self.post_audio(audio_url), self.post_image(image_url)
        )
        payload = {
            "voiceUrl": self._uploaded_url(audio_result, "audio"),
            "avatarImage": self._uploaded_url(image_result, "image"),
            "aspectRatio": aspect_ratio,
        }
        async with aiohttp.ClientSession(headers=self._headers, timeout=self.timeout) as session:
            async with session.post(f"{self._base_url}/v1/characters", json=payload) as response:
                data = await response.json()
        task_id = data.get("jobId")
        # Explicit raise instead of assert: asserts are stripped under -O.
        if task_id is None:
            raise RuntimeError(f"Failed to submit task, {data}")
        return task_id, data.get("request_id")

    async def get_response(self, task_id: str, poll_interval: float = 4) -> Tuple[str, float]:
        """Poll the project for *task_id* until it completes.

        Args:
            task_id: Job id returned by ``submit_task``.
            poll_interval: Seconds to wait between status checks.

        Returns:
            ``(video_url, video_duration)``.

        Raises:
            RuntimeError: if the status is ``Failed`` or missing, or a
                completed response has no ``videoUrl``.
        """
        url = self._check_task_url.format(task_id=task_id)
        async with aiohttp.ClientSession(headers=self._headers, timeout=self.timeout) as session:
            while True:
                async with session.get(url) as response:
                    data = await response.json()
                # The response is released before sleeping, so the connection
                # is not held open between polls.
                status = data.get("status")
                if status == "Completed":
                    video_url = data.get("videoUrl")
                    if video_url is None:
                        raise RuntimeError(f"Failed to get video_url from response[{data}]")
                    # NOTE(review): duration is a hard-coded placeholder, not
                    # read from the API response.
                    video_duration = 4
                    return video_url, video_duration
                if status is None or status == "Failed":
                    # ``output`` may be absent or null; guard both cases.
                    message = (data.get("output") or {}).get("message", "")
                    raise RuntimeError(f"Task {task_id} failed or was canceled. {message}")
                await asyncio.sleep(poll_interval)
def gradio_interface(audio_url, image_url, aspect_ratio):
    """Gradio callback: run the full Hedra pipeline synchronously.

    Submits the audio/image URLs with the given aspect ratio, waits for the
    job to finish and returns ``(video_url, video_duration)``.
    """
    client = HedraClient()

    async def process():
        task_id, _request_id = await client.submit_task(audio_url, image_url, aspect_ratio)
        return await client.get_response(task_id)

    # asyncio.run creates a fresh event loop and always closes it, unlike the
    # previous new_event_loop()/run_until_complete() pair, which leaked one
    # unclosed loop per request.
    video_url, video_duration = asyncio.run(process())
    return video_url, video_duration
# Build the web UI. ``gr.Textbox`` replaces the ``gr.inputs``/``gr.outputs``
# namespaces, which were deprecated in Gradio 3 and removed in Gradio 4.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(label="Audio URL"),
        gr.Textbox(label="Image URL"),
        gr.Textbox(label="Aspect Ratio"),
    ],
    outputs=[
        gr.Textbox(label="Video URL"),
        gr.Textbox(label="Video Duration"),
    ],
    title="Hedra Gradio Interface",
    description="Submit audio and image URLs to generate a video.",
)
# NOTE(review): launching at import time is kept for compatibility; consider
# guarding with ``if __name__ == "__main__":`` if this module is ever imported.
iface.launch(share=True)