# V-Express / app.py
import gradio as gr
import git
import os
import shutil
import subprocess
import torchaudio
import torch
# Clone the V-Express repository if not already cloned
repo_url = "https://github.com/tencent-ailab/V-Express"
repo_dir = "V-Express"
hf_model_repo_url = "https://huggingface.co/tk93/V-Express"
hf_model_repo_dir = "V-Express-models"
output_dir = "output"
temp_audio_path = "temp.mp3"
if not os.path.exists(repo_dir):
    git.Repo.clone_from(repo_url, repo_dir)
# Install Git LFS and clone the HuggingFace model repository
def setup_models():
    subprocess.run(["git", "lfs", "install"], check=True)
    if not os.path.exists(hf_model_repo_dir):
        git.Repo.clone_from(hf_model_repo_url, hf_model_repo_dir)
    # Move the model_ckpts directory to the correct location
    src = os.path.join(hf_model_repo_dir, "model_ckpts")
    dst = os.path.join(repo_dir, "model_ckpts")
    if os.path.exists(src):
        if os.path.exists(dst):
            shutil.rmtree(dst)
        shutil.move(src, dst)
setup_models()
result_path = os.path.join(repo_dir, output_dir)
if not os.path.exists(result_path):
    os.mkdir(result_path)
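# Work from inside the cloned repo so the relative paths used below
# (scripts/, output/, test_samples/) resolve against the V-Express checkout.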
os.chdir(repo_dir)
# Function to run V-Express demo
def run_demo(
        reference_image, audio, video,
        kps_path, output_path, retarget_strategy,
        reference_attention_weight=0.95,
        audio_attention_weight=3.0,
        progress=gr.Progress()):
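    """Generate a talking-face video with V-Express.

    If a driving video is given, its keypoint sequence and audio track are
    extracted first; otherwise the supplied audio file and kps.pth file are
    used directly. Returns the path of the generated video and the kps.pth
    file used for it.
    """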
    # Step 1: Extract keypoints from the driving video
    progress((0, 100), desc="Starting...")
    kps_sequence_save_path = f"./{output_dir}/kps.pth"
    if video is not None:
        # Run the script to extract keypoints and audio from the video
        progress((25, 100), desc="Extracting keypoints and audio...")
        audio_path = video.replace(".mp4", ".mp3")
        subprocess.run([
            "python",
            "scripts/extract_kps_sequence_and_audio.py",
            "--video_path", video,
            "--kps_sequence_save_path", kps_sequence_save_path,
            "--audio_save_path", audio_path
        ], check=True)
        progress((50, 100), desc="Keypoints and audio extracted successfully.")
        rem_progress = (75, 100)
    else:
        # No driving video: use the supplied audio and keypoint sequence file
        rem_progress = (50, 100)
        audio_path = audio
        shutil.copy(kps_path, kps_sequence_save_path)
    # Re-encode the audio to a clean MP3 and replace the original file with it
    subprocess.run(["ffmpeg", "-y", "-i", audio_path, "-c:a", "libmp3lame", "-q:a", "2", temp_audio_path], check=True)
    shutil.move(temp_audio_path, audio_path)
    # Step 2: Run inference with the reference image and audio
    # Build the parameters for the selected retargeting strategy
    progress(rem_progress, desc="Inference...")
    inference_script = "inference.py"
    inference_params = [
        "--reference_image_path", reference_image,
        "--audio_path", audio_path,
        "--kps_path", kps_sequence_save_path,
        "--output_path", output_path,
        "--retarget_strategy", retarget_strategy,
        "--num_inference_steps", "30",  # Hardcoded for now, can be adjusted
        "--reference_attention_weight", str(reference_attention_weight),
        "--audio_attention_weight", str(audio_attention_weight)
    ]
    # Run the inference script with the provided parameters
    subprocess.run(["python", inference_script] + inference_params, check=True)
    status = f"Video generated successfully. Saved at: {output_path}"
    progress((100, 100), desc=status)
    # Return the generated video and the kps.pth file actually used for it
    return output_path, kps_sequence_save_path
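# For reference, the subprocess call above amounts to running, inside the
# V-Express checkout (values shown are the interface defaults):
#   python inference.py \
#       --reference_image_path <reference image> \
#       --audio_path <audio file> \
#       --kps_path ./output/kps.pth \
#       --output_path ./output/output_video.mp4 \
#       --retarget_strategy no_retarget \
#       --num_inference_steps 30 \
#       --reference_attention_weight 0.95 \
#       --audio_attention_weight 3.0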
# Create Gradio interface
inputs = [
    gr.Image(label="Reference Image", type="filepath"),
    gr.Audio(label="Audio", type="filepath"),
    gr.Video(label="Video"),
    gr.File(label="KPS sequences", value="test_samples/short_case/10/kps.pth"),
    gr.Textbox(label="Output Path for generated video", value=f"./{output_dir}/output_video.mp4"),
    gr.Dropdown(label="Retargeting Strategy", choices=["no_retarget", "fix_face", "offset_retarget", "naive_retarget"], value="no_retarget"),
    gr.Slider(label="Reference Attention Weight", minimum=0.0, maximum=1.0, step=0.01, value=0.95),
    gr.Slider(label="Audio Attention Weight", minimum=1.0, maximum=3.0, step=0.1, value=3.0)
]
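# gr.Interface passes these components to run_demo positionally, in the same
# order as the function's parameters (reference_image, audio, video, kps_path,
# output_path, retarget_strategy, reference_attention_weight, audio_attention_weight).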
output = [
    gr.Video(label="Generated Video"),
    gr.File(label="Generated KPS Sequences File (kps.pth)")
]
# Title and description for the interface
title = "V-Express Gradio Interface"
description = "An interactive interface for generating talking face videos using V-Express."
# Launch Gradio app
demo = gr.Interface(run_demo, inputs, output, title=title, description=description)
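# The queue must be enabled for gr.Progress updates to reach the browser;
# it also processes long-running inference requests one at a time by default.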
demo.queue().launch()