Wan2.2-S2V-14B / app.py
immango's picture
Create app.py
637d3ba verified
import gradio as gr
import subprocess
import tempfile
import os
import sys
import shutil
from pathlib import Path
import time
# 输出 Gradio 版本信息
print(f"===== Application Startup at {time.strftime('%Y-%m-%d %H:%M:%S')} =====")
print(f"Gradio version: {gr.__version__}")
print(f"Python version: {sys.version}")
print(f"Python executable: {sys.executable}")
print("=" * 60)
class Wan2S2VPipeline:
def __init__(self):
self.model_loaded = False
self.model_path = None
self.script_path = None
self.ckpt_dir = None
self.model_repo = "Wan-AI/Wan2.2-S2V-14B"
def load_model(self):
"""下载Wan2.2-S2V-14B模型和脚本"""
try:
if self.model_loaded:
return True, "模型已加载"
# 设置工作目录(使用持久目录)
work_dir = "/tmp/wan2.2"
os.makedirs(work_dir, exist_ok=True)
# 步骤1: 克隆官方代码仓库
print("步骤1: 克隆官方代码仓库...")
repo_path = os.path.join(work_dir, "Wan2.2")
if not os.path.exists(os.path.join(repo_path, ".git")):
# 如果目录不存在或不是git仓库,则克隆
if os.path.exists(repo_path):
shutil.rmtree(repo_path)
result = subprocess.run(
["git", "clone", "https://github.com/Wan-Video/Wan2.2.git", repo_path],
capture_output=True,
text=True,
timeout=300
)
if result.returncode != 0:
return False, f"❌ 克隆代码仓库失败: {result.stderr}"
print("✅ 代码仓库克隆成功")
else:
print("✅ 代码仓库已存在,跳过克隆")
# 步骤2: 下载模型权重
print("步骤2: 下载模型权重...")
model_dir = os.path.join(work_dir, "Wan2.2-S2V-14B")
if not os.path.exists(model_dir):
from huggingface_hub import snapshot_download
print(f"正在下载模型 {self.model_repo}...")
model_path = snapshot_download(
repo_id=self.model_repo,
cache_dir="/tmp/hf_cache",
local_dir=model_dir,
local_dir_use_symlinks=False
)
print(f"✅ 模型权重下载完成: {model_path}")
else:
print("✅ 模型权重已存在,跳过下载")
# 步骤3: 安装依赖
print("步骤3: 安装依赖...")
requirements_file = os.path.join(repo_path, "requirements.txt")
if os.path.exists(requirements_file):
try:
result = subprocess.run(
[sys.executable, "-m", "pip", "install", "-r", requirements_file],
capture_output=True,
text=True,
timeout=600,
cwd=repo_path
)
if result.returncode == 0:
print("✅ 依赖安装成功")
else:
print(f"⚠️ 依赖安装警告: {result.stderr}")
except Exception as e:
print(f"⚠️ 依赖安装跳过: {e}")
else:
print("⚠️ 未找到 requirements.txt,跳过依赖安装")
# 步骤4: 设置路径
self.model_path = repo_path
self.script_path = os.path.join(repo_path, "generate.py")
self.ckpt_dir = model_dir
# 验证文件
if not os.path.exists(self.script_path):
return False, "❌ 未找到 generate.py 脚本"
if not os.path.exists(self.ckpt_dir):
return False, "❌ 未找到模型权重目录"
self.model_loaded = True
print("🎉 Wan2.2-S2V-14B 模型准备完成!")
return True, "✅ 模型加载成功!"
except Exception as e:
error_msg = f"模型加载失败: {str(e)}"
print(error_msg)
return False, f"❌ {error_msg}"
def generate(self, task, size, prompt, image_file, audio_file,
num_frames=16, guidance_scale=7.5,
num_inference_steps=20, seed=-1, offload_model=True,
convert_model_dtype=True):
"""执行Wan2.2-S2V-14B生成命令"""
try:
if not self.model_loaded:
success, message = self.load_model()
if not success:
return None, message
# 设置环境变量解决 OMP_NUM_THREADS 问题
env = os.environ.copy()
env["OMP_NUM_THREADS"] = "1"
env["TOKENIZERS_PARALLELISM"] = "false"
# 验证必需参数
if not prompt or not prompt.strip():
return None, "❌ 提示词不能为空"
if not image_file:
return None, "❌ 请上传输入图片"
if not audio_file:
return None, "❌ 请上传输入音频"
# 构建命令行参数
cmd = [sys.executable, self.script_path]
# 必需参数
cmd.extend(["--task", task])
cmd.extend(["--size", size])
cmd.extend(["--ckpt_dir", self.ckpt_dir])
cmd.extend(["--prompt", prompt])
cmd.extend(["--image", image_file])
cmd.extend(["--audio", audio_file])
# 可选参数
if num_frames is not None:
cmd.extend(["--frame_num", str(num_frames)])
# 使用 infer_frames 替代 fps 参数
cmd.extend(["--infer_frames", str(num_frames)])
if guidance_scale is not None:
cmd.extend(["--sample_guide_scale", str(guidance_scale)])
if num_inference_steps is not None:
cmd.extend(["--sample_steps", str(num_inference_steps)])
if seed is not None and seed != -1:
cmd.extend(["--base_seed", str(seed)])
# 模型优化参数
if offload_model:
cmd.extend(["--offload_model", "True"])
else:
cmd.extend(["--offload_model", "False"])
if convert_model_dtype:
cmd.append("--convert_model_dtype")
print(f"执行命令: {' '.join(cmd)}")
# 创建临时输出目录
output_dir = os.path.join(self.model_path, "outputs")
os.makedirs(output_dir, exist_ok=True)
# 执行命令(实时输出日志)
start_time = time.time()
print("🚀 开始执行 generate.py 脚本...")
print("=" * 50)
# 使用 Popen 实现实时日志输出
process = subprocess.Popen(
cmd,
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT, # 将 stderr 重定向到 stdout
text=True,
bufsize=1, # 行缓冲
cwd=self.model_path,
env=env
)
# 实时读取输出(带超时检查)
all_output = []
start_read_time = time.time()
timeout_seconds = 3600 # 10分钟超时
while True:
# 检查是否超时
if time.time() - start_read_time > timeout_seconds:
process.terminate() # 尝试优雅终止
try:
process.wait(timeout=10) # 等待10秒
except subprocess.TimeoutExpired:
process.kill() # 强制终止
raise subprocess.TimeoutExpired(cmd, timeout_seconds)
# 尝试读取输出(非阻塞)
output_line = process.stdout.readline()
if output_line == '' and process.poll() is not None:
break
if output_line:
output_line = output_line.strip()
if output_line: # 忽略空行
print(f"[generate.py] {output_line}")
all_output.append(output_line)
# 重置超时计时器(有输出说明脚本还在运行)
start_read_time = time.time()
# 等待进程完成
return_code = process.wait()
execution_time = time.time() - start_time
print("=" * 50)
print(f"脚本执行完成,返回码: {return_code}")
print(f"总耗时: {execution_time:.1f}秒")
if return_code == 0:
print("✅ 命令执行成功")
# 构建详细的成功消息
success_msg = f"✅ 生成成功!耗时: {execution_time:.1f}秒\n\n"
if all_output:
success_msg += f"脚本输出:\n" + "\n".join(all_output) + "\n"
# 查找输出文件
output_files = self._find_output_files()
if output_files:
# 直接返回原始输出文件路径
output_file = output_files[0]
print(f"找到输出文件: {output_file}")
return output_file, success_msg
else:
return None, f"⚠️ 生成成功但未找到输出文件\n\n脚本输出:\n" + "\n".join(all_output)
else:
# 构建详细的错误消息
error_msg = f"脚本执行失败,返回码: {return_code}\n\n"
if all_output:
error_msg += f"脚本输出:\n" + "\n".join(all_output)
else:
error_msg += "无输出信息"
print(f"❌ 命令执行失败: {error_msg}")
return None, f"❌ 生成失败:\n{error_msg}"
except subprocess.TimeoutExpired:
return None, "⏰ 生成超时(10分钟),请尝试减少参数或检查模型状态"
except Exception as e:
error_msg = f"执行失败: {str(e)}"
print(error_msg)
return None, f"❌ {error_msg}"
def _find_output_files(self):
"""查找输出文件"""
output_extensions = ['.mp4', '.gif', '.avi', '.mov', '.png', '.jpg', '.jpeg']
output_files = []
# 优先搜索 outputs 目录
outputs_dir = os.path.join(self.model_path, "outputs")
if os.path.exists(outputs_dir):
for ext in output_extensions:
for file_path in Path(outputs_dir).rglob(f"*{ext}"):
if file_path.is_file():
output_files.append(str(file_path))
print(f"在 outputs 目录找到文件: {file_path}")
# 如果没有找到,搜索整个模型目录
if not output_files:
print("在 outputs 目录未找到文件,搜索整个模型目录...")
for ext in output_extensions:
for file_path in Path(self.model_path).rglob(f"*{ext}"):
if file_path.is_file():
# 排除一些不需要的文件
file_path_str = str(file_path)
if not any(exclude in file_path_str.lower() for exclude in ['.git', '__pycache__', 'node_modules']):
output_files.append(file_path_str)
print(f"在模型目录找到文件: {file_path_str}")
# 按修改时间排序,最新的文件在前面
if output_files:
output_files.sort(key=lambda x: os.path.getmtime(x), reverse=True)
print(f"找到 {len(output_files)} 个输出文件,按时间排序")
return output_files
def _copy_output_for_display(self, output_file):
"""复制输出文件到临时目录以便Gradio显示(已弃用)"""
# 此方法已不再使用,直接返回原始文件路径
print(f"直接使用原始文件: {output_file}")
return output_file
# 创建全局实例
pipeline = Wan2S2VPipeline()
def generate_interface(task, size, prompt, image_file, audio_file,
num_frames, guidance_scale, num_inference_steps,
seed, offload_model, convert_model_dtype):
"""Gradio 界面函数"""
# 执行生成
result, message = pipeline.generate(
task=task,
size=size,
prompt=prompt,
image_file=image_file,
audio_file=audio_file,
num_frames=num_frames,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
seed=seed,
offload_model=offload_model,
convert_model_dtype=convert_model_dtype
)
return result, message
def load_model_interface():
"""加载模型界面函数"""
success, message = pipeline.load_model()
return message
# 创建 Gradio 界面
with gr.Blocks(title="Wan2.2-S2V-14B 视频生成器") as demo:
gr.Markdown("""
# 使用前说明:本项目无法正常运行是因为没有选择GPU部署
# 完整的运行,请参考工程Files或者复制这个space,部署时最低选择 Nvidia 1xL40S 48G VRAM
# 🎬 Wan2.2-S2V-14B 视频生成器
**模型介绍**: Wan2.2-S2V-14B 是一个强大的图像到视频生成模型,支持音频引导。
**使用方法**:
1. 点击"🚀 加载模型"按钮下载模型
2. 填写提示词、上传图片和音频
3. 调整参数后点击"🎬 开始生成"
**注意**: 首次使用需要下载约14GB的模型文件,请耐心等待。
""")
with gr.Row():
with gr.Column(scale=1):
# 模型加载
gr.Markdown("### 📥 模型管理")
load_btn = gr.Button("🚀 加载模型", variant="primary", size="lg")
load_status = gr.Textbox(label="模型状态", interactive=False, value="等待加载模型...")
# 必需参数
gr.Markdown("### 📝 必需参数")
task = gr.Textbox(
label="任务类型",
value="s2v-14B",
interactive=False
)
size = gr.Dropdown(
label="分辨率",
choices=["1024*704", "1024*1024", "704*1024", "512*512"],
value="1024*704"
)
prompt = gr.Textbox(
label="提示词 *",
lines=3,
placeholder="例如: Summer beach vacation style, a white cat wearing sunglasses sits on a surfboard."
)
image = gr.Image(
label="输入图片 *",
type="filepath"
)
audio = gr.Audio(
label="输入音频 *",
type="filepath"
)
# 高级参数
with gr.Accordion("🔧 高级参数", open=False):
num_frames = gr.Slider(
8, 32, 16,
step=1,
label="帧数 (frame_num/infer_frames)"
)
guidance_scale = gr.Slider(
1.0, 20.0, 7.5,
step=0.1,
label="引导强度 (sample_guide_scale)"
)
num_inference_steps = gr.Slider(
10, 100, 20,
step=1,
label="推理步数 (sample_steps)"
)
seed = gr.Number(
label="随机种子 (base_seed)",
value=-1
)
with gr.Row():
offload_model = gr.Checkbox(
label="模型卸载",
value=True
)
convert_model_dtype = gr.Checkbox(
label="转换数据类型",
value=True
)
# 生成按钮
generate_btn = gr.Button("🎬 开始生成", variant="primary", size="lg")
with gr.Column(scale=1):
# 输出结果
gr.Markdown("### 🎥 生成结果")
output = gr.File(label="输出视频")
status = gr.Textbox(label="生成状态", interactive=False, lines=3)
# 使用说明
gr.Markdown("""
### 📋 使用说明
**参数说明**:
- **分辨率**: 选择适合你需求的视频尺寸
- **提示词**: 用英文描述想要的视频内容,越详细越好
- **图片**: 上传参考图片,模型会基于此生成视频
- **音频**: 上传音频文件,模型会结合音频内容生成视频
**高级参数**:
- **帧数 (frame_num/infer_frames)**: 控制视频长度,8-32帧
- **引导强度 (sample_guide_scale)**: 生成质量控制,1.0-20.0
- **推理步数 (sample_steps)**: 生成精度,10-100步
- **随机种子 (base_seed)**: 结果重现,-1为随机
**优化建议**:
- 首次使用建议保持默认参数
- 如果显存不足,可以降低分辨率和帧数
- 提示词使用英文效果更好
- 音频文件建议使用清晰的语音或音乐
**注意事项**:
- 生成时间取决于参数设置,通常需要5-10分钟
- 确保上传的图片和音频文件格式正确
- 如果遇到错误,请检查参数设置和文件格式
""")
# 事件绑定
load_btn.click(load_model_interface, outputs=load_status)
generate_btn.click(
generate_interface,
inputs=[
task, size, prompt, image, audio,
num_frames, guidance_scale, num_inference_steps,
seed, offload_model, convert_model_dtype
],
outputs=[output, status]
)
# 启动应用
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)