# MTVCrafter Gradio inference demo.
# NOTE(review): the original capture carried Hugging Face Spaces page residue
# ("Spaces:" / "Runtime error" banner lines) here — that was page chrome, not code.
import os
import tempfile
import traceback

import cv2
import gradio as gr
import torch
from PIL import Image

from inference_engine import run_inference
from motion_extractor import extract_pkl_from_video
# Run on GPU when one is present; otherwise fall back to CPU inference.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
def full_pipeline(video_file, ref_image=None, width=512, height=512, steps=50, scale=3.0, seed=6666):
    """Extract human motion from a dancing video and animate a reference image.

    Args:
        video_file: Uploaded file object (has a ``.name`` path attribute).
        ref_image: Optional PIL image to animate; when ``None`` the inference
            engine falls back to the video's first frame.
        width, height: Output resolution passed to the inference engine.
        steps: Number of diffusion inference steps.
        scale: Classifier-free guidance scale.
        seed: Random seed for reproducible generation.

    Returns:
        Path to the generated video, as returned by ``run_inference``.
    """
    # 1. Extract the motion .pkl from the uploaded video.
    video_path = video_file.name
    motion_pkl_path = extract_pkl_from_video(video_path)
    gr.Info("⏳ Extract motion finished and begin animation...", visible=True)
    # 2. Optional reference image: persist it to disk for the engine.
    if ref_image is not None:
        # Use a unique temp file instead of the old fixed "temp_ref.png" so
        # concurrent requests cannot overwrite each other's reference image.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            ref_path = tmp.name
        ref_image.save(ref_path)
    else:
        # Empty path signals run_inference to use the video's first frame.
        ref_path = ""
    # 3. Run diffusion inference.
    output_path = run_inference(
        device,
        motion_pkl_path,
        ref_path,
        dst_width=width,
        dst_height=height,
        num_inference_steps=steps,
        guidance_scale=scale,
        seed=seed,
    )
    return output_path
def run_pipeline_with_feedback(video_file, ref_image, width, height, steps, scale, seed):
    """UI-facing wrapper around :func:`full_pipeline` with user feedback.

    Validates the upload, surfaces progress via ``gr.Info``, and converts any
    failure into a ``gr.Warning`` plus a ``None`` result (so the output video
    component simply stays empty instead of crashing the app).
    """
    try:
        if video_file is None:
            raise gr.Error("Please upload a dancing video (.mp4/.mov/.avi).")
        # Progress hint — inference takes several minutes.
        gr.Info("⏳ Processing... Please wait several minutes.", visible=True)
        result = full_pipeline(video_file, ref_image, width, height, steps, scale, seed)
        gr.Info("✅ Inference done, please enjoy it!", visible=True)
        return result
    except Exception as e:
        # Fix: `traceback` was used here without ever being imported, which
        # turned every failure into a NameError (now imported at file top).
        traceback.print_exc()
        gr.Warning("⚠️ Inference failed: " + str(e))
        return None
# Build the Gradio UI.
with gr.Blocks(title="MTVCrafter Inference Demo") as demo:
    gr.Markdown(
        """
        # 🎨💃 MTVCrafter Inference Demo
        💡 **Tip:** Upload a dancing video in **MP4/MOV/AVI** format, and optionally a reference image (e.g., PNG or JPG).
        This demo will extract human motion from the input video and animate the reference image accordingly.
        If no reference image is provided, the **first frame** of the video will be used as the reference.
        🎞️ **Note:** The generated output video will contain exactly **49 frames**.
        """
    )
    with gr.Row():
        # Left column: video upload plus an inline preview.
        with gr.Column(scale=1):
            video_input = gr.File(label="📹 Input Video (Required)", file_types=[".mp4", ".mov", ".avi"])
            # Fixed height keeps the two columns visually aligned.
            video_preview = gr.Video(label="👀 Preview of Uploaded Video", height=280)
            # Mirror the uploaded file into the preview player (clears on removal).
            video_input.change(
                fn=lambda f: f.name if f else None,
                inputs=video_input,
                outputs=video_preview,
            )
        # Right column: optional reference image.
        with gr.Column(scale=1):
            ref_image = gr.Image(type="pil", label="🖼️ Reference Image (Optional)", height=538)
    # Collapsed-by-default knobs for resolution, sampling, and seeding.
    with gr.Accordion("⚙️ Advanced Settings", open=False):
        with gr.Row():
            width = gr.Slider(384, 1024, value=512, step=16, label="Output Width")
            height = gr.Slider(384, 1024, value=512, step=16, label="Output Height")
        with gr.Row():
            steps = gr.Slider(20, 100, value=50, step=5, label="Inference Steps")
            scale = gr.Slider(0.0, 10.0, value=3.0, step=0.25, label="Guidance Scale")
            seed = gr.Number(value=6666, label="Random Seed")
    with gr.Row(scale=1):
        output_video = gr.Video(label="🎬 Generated Video", interactive=False)
    run_btn = gr.Button("🚀 Run MTVCrafter", variant="primary")
    run_btn.click(
        fn=run_pipeline_with_feedback,
        inputs=[video_input, ref_image, width, height, steps, scale, seed],
        outputs=output_video,
    )
if __name__ == "__main__":
    # Keep local/loopback traffic off any configured proxy.
    os.environ["NO_PROXY"] = "localhost,127.0.0.1/8,::1"
    # Route Hugging Face Hub downloads through the mirror endpoint.
    os.environ["HF_ENDPOINT"] = "https://hf-mirror.com/"
    # Bind on all interfaces and also request a public Gradio share link.
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)