import gradio as gr
import git
import os
import shutil
import subprocess
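
# Upstream V-Express code repository, the Hugging Face repository that hosts
# the pretrained checkpoints, and the local paths this demo works with.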
repo_url = "https://github.com/tencent-ailab/V-Express"
repo_dir = "V-Express"

hf_model_repo_url = "https://huggingface.co/tk93/V-Express"
hf_model_repo_dir = "V-Express-models"

output_dir = "output"
temp_audio_path = "temp.mp3"
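
# Clone the V-Express source tree on first run; later runs reuse it.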
if not os.path.exists(repo_dir):
    git.Repo.clone_from(repo_url, repo_dir)


def setup_models():
    """Fetch the pretrained checkpoints and move them into the code tree."""
    # Git LFS is required so the checkpoint clone downloads the large files.
    subprocess.run(["git", "lfs", "install"], check=True)

    if not os.path.exists(hf_model_repo_dir):
        git.Repo.clone_from(hf_model_repo_url, hf_model_repo_dir)

    # Relocate model_ckpts from the checkpoint clone into the code tree,
    # where the inference scripts expect to find it.
    src = os.path.join(hf_model_repo_dir, "model_ckpts")
    dst = os.path.join(repo_dir, "model_ckpts")
    if os.path.exists(src):
        if os.path.exists(dst):
            shutil.rmtree(dst)
        shutil.move(src, dst)


setup_models()
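
# Results are written under V-Express/output; the helper scripts use paths
# relative to the repo root, so run everything from inside the repo.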
result_path = os.path.join(repo_dir, output_dir)
if not os.path.exists(result_path):
    os.mkdir(result_path)

os.chdir(repo_dir)


def run_demo(
        reference_image, audio, video,
        kps_path, output_path, retarget_strategy,
        reference_attention_weight=0.95,
        audio_attention_weight=3.0,
        progress=gr.Progress()):
    """Generate a talking-face video by shelling out to the V-Express scripts."""
    progress((0, 100), desc="Starting...")

    # Keypoint sequence consumed by inference.py: either extracted from the
    # driving video below, or copied from the uploaded kps.pth file.
    kps_sequence_save_path = f"./{output_dir}/kps.pth"

    if video is not None:
        # A driving video was supplied: extract its keypoint sequence and
        # audio track with the helper script shipped in the repo.
        progress((25, 100), desc="Extracting keypoints and audio...")
        audio_path = os.path.splitext(video)[0] + ".mp3"

        subprocess.run([
            "python",
            "scripts/extract_kps_sequence_and_audio.py",
            "--video_path", video,
            "--kps_sequence_save_path", kps_sequence_save_path,
            "--audio_save_path", audio_path,
        ], check=True)
        progress((50, 100), desc="Keypoints and audio extracted successfully.")

        rem_progress = (75, 100)
    else:
        # No driving video: use the uploaded audio and keypoint file directly.
        rem_progress = (50, 100)
        audio_path = audio
        shutil.copy(kps_path, kps_sequence_save_path)

    # Re-encode the audio to MP3 so inference always receives a consistent
    # format; -y silently overwrites a stale temp.mp3 from an earlier run.
    subprocess.run(
        ["ffmpeg", "-y", "-i", audio_path, "-c:a", "libmp3lame", temp_audio_path],
        check=True)
    shutil.move(temp_audio_path, audio_path)

    progress(rem_progress, desc="Inference...")
    inference_script = "inference.py"
    inference_params = [
        "--reference_image_path", reference_image,
        "--audio_path", audio_path,
        "--kps_path", kps_sequence_save_path,
        "--output_path", output_path,
        "--retarget_strategy", retarget_strategy,
        "--num_inference_steps", "30",
        "--reference_attention_weight", str(reference_attention_weight),
        "--audio_attention_weight", str(audio_attention_weight),
    ]

    subprocess.run(["python", inference_script] + inference_params, check=True)
    status = f"Video generated successfully. Saved at: {output_path}"
    progress((100, 100), desc=status)
    # Return the keypoint file that was actually used, so it can be downloaded
    # and reused for later runs that skip video extraction.
    return output_path, kps_sequence_save_path
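

# Gradio widgets: supply either a driving video, or an audio clip plus a
# precomputed keypoint sequence (kps.pth).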
inputs = [
    gr.Image(label="Reference Image", type="filepath"),
    gr.Audio(label="Audio", type="filepath"),
    gr.Video(label="Video"),
    gr.File(label="KPS sequences", value="test_samples/short_case/10/kps.pth"),
    gr.Textbox(label="Output Path for generated video", value=f"./{output_dir}/output_video.mp4"),
    gr.Dropdown(label="Retargeting Strategy", choices=["no_retarget", "fix_face", "offset_retarget", "naive_retarget"], value="no_retarget"),
    gr.Slider(label="Reference Attention Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.95),
    gr.Slider(label="Audio Attention Weight", minimum=1.0, maximum=3.0, step=0.1, value=3.0),
]

output = [
    gr.Video(label="Generated Video"),
    gr.File(label="Generated KPS Sequences File (kps.pth)"),
]

title = "V-Express Gradio Interface"
description = "An interactive interface for generating talking face videos using V-Express."
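
# queue() enables streaming progress updates to the browser while the
# long-running inference subprocess executes.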
demo = gr.Interface(run_demo, inputs, output, title=title, description=description)
demo.queue().launch()