Update app.py

app.py CHANGED
@@ -44,7 +44,7 @@ def filter_kwargs(cls, kwargs):
     return filtered_kwargs
 
 def download_file(url, local_path):
-    """從 URL
+    """從 URL 下載檔案,如果 URL 是本地路徑則直接返回"""
     if url.startswith(('http://', 'https://')):
         print(f"從 {url} 下載檔案中...")
         try:
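The docstring added here states `download_file`'s contract: remote URLs are fetched, local paths pass through, and a missing input yields `None`. A minimal sketch of that contract, assuming a stdlib `urllib` transport (the function body itself is outside this hunk):

```python
import os
import urllib.request

def download_file_sketch(url, local_path):
    # Remote input: fetch into the temp dir. urllib is an assumed stand-in
    # for whatever HTTP client app.py actually uses.
    if url.startswith(('http://', 'https://')):
        urllib.request.urlretrieve(url, local_path)
        return local_path
    # Local input: hand the path back untouched.
    if os.path.exists(url):
        return url
    # Missing input: callers treat None as "abort and clean up".
    return None
```

Returning `None` rather than raising is what lets `main()` do the `shutil.rmtree(temp_dir); return` cleanup seen later in this diff.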
@@ -65,12 +65,15 @@ def download_file(url, local_path):
         print(f"錯誤:檔案或 URL 不存在: {url}")
         return None
 
-def setup_models(repo_root):
+def setup_models(repo_root, model_version):
     """載入所有必要的模型和設定"""
     pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
     pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
 
-    …
+    config_path = os.path.join(repo_root, "deepspeed_config/wan2.1/wan_civitai.yaml")
+    if not os.path.exists(config_path):
+        raise FileNotFoundError(f"設定檔未找到: {config_path}")
+    config = OmegaConf.load(config_path)
     sampler_name = "Flow"
 
     print("正在載入 Tokenizer...")
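`setup_models` now takes the `model_version` and fails fast if the OmegaConf YAML is missing, instead of erroring deep inside model construction. Since `snapshot_download` only fetches what `allow_patterns` lists (and `"deepspeed_config/**"` is on that list below), an incomplete snapshot is the likely failure this guard defends against. The guard generalizes to a small helper, sketched here under the same path assumptions:

```python
import os
from omegaconf import OmegaConf

def load_config_or_die(repo_root, rel_path="deepspeed_config/wan2.1/wan_civitai.yaml"):
    # Surface the concrete missing path up front rather than letting
    # OmegaConf fail with a less obvious error mid-setup.
    config_path = os.path.join(repo_root, rel_path)
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"config not found: {config_path}")
    return OmegaConf.load(config_path)
```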
@@ -97,7 +100,7 @@ def setup_models(repo_root):
     print("正在載入 CLIP Image Encoder...")
     clip_image_encoder = CLIPModel.from_pretrained(os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'))).eval()
 
-    print("正在載入 Transformer 3D
+    print("正在載入 Transformer 3D 基礎模型...")
     transformer3d = WanTransformer3DFantasyModel.from_pretrained(
         os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
         transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
@@ -105,99 +108,55 @@ def setup_models(repo_root):
         torch_dtype=dtype,
     )
 
-    …
+    # <<< FIX 1: 載入 StableAvatar 專用權重 >>>
+    if model_version == "square":
+        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
+    else: # rec_vec
+        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-rec-vec.pt")
+
+    if os.path.exists(transformer_path):
+        print(f"正在從 {transformer_path} 載入 StableAvatar 權重...")
+        state_dict = torch.load(transformer_path, map_location="cpu")
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+        m, u = transformer3d.load_state_dict(state_dict, strict=False)
+        print(f"StableAvatar 權重載入成功。 Missing keys: {len(m)}; Unexpected keys: {len(u)}")
+    else:
+        raise FileNotFoundError(f"找不到 StableAvatar 權重檔案:{transformer_path}。請確保模型已完整下載。")
+    # <<< END OF FIX 1 >>>
+
+    scheduler_class = { "Flow": FlowMatchEulerDiscreteScheduler }[sampler_name]
+    scheduler = scheduler_class(**filter_kwargs(scheduler_class, OmegaConf.to_container(config['scheduler_kwargs'])))
 
     print("正在建立 Pipeline...")
     pipeline = WanI2VTalkingInferenceLongPipeline(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        transformer=transformer3d,
-        clip_image_encoder=clip_image_encoder,
-        scheduler=scheduler,
-        wav2vec_processor=wav2vec_processor,
-        wav2vec=wav2vec,
+        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae,
+        transformer=transformer3d, clip_image_encoder=clip_image_encoder,
+        scheduler=scheduler, wav2vec_processor=wav2vec_processor, wav2vec=wav2vec,
     )
 
     return pipeline, transformer3d, vae
 
 def run_inference(
-    pipeline,
-    transformer3d,
-    vae,
-    image_path,
-    audio_path,
-    prompt,
-    negative_prompt,
-    seed,
-    output_filename,
-    gpu_memory_mode="model_cpu_offload",
-    teacache_threshold=0,
-    num_skip_start_steps=5,
-    width=512,
-    height=512,
-    guidance_scale=6.0,
-    num_inference_steps=50,
-    text_guide_scale=3.0,
-    audio_guide_scale=5.0,
-    motion_frame=25,
-    fps=25,
-    overlap_window_length=10,
-    overlapping_weight_scheme="uniform",
-    clip_sample_n_frames=81,
+    pipeline, transformer3d, vae, image_path, audio_path, prompt,
+    negative_prompt, seed, output_filename, gpu_memory_mode="model_cpu_offload",
+    width=512, height=512, num_inference_steps=50, fps=25, **kwargs
 ):
-    """
-    執行推理以生成影片。
-
-    Args:
-        pipeline: 推理 pipeline。
-        transformer3d: 3D transformer 模型。
-        vae: VAE 模型。
-        image_path (str): 輸入圖片的路徑。
-        audio_path (str): 輸入音訊的路徑。
-        prompt (str): 正面提示詞。
-        negative_prompt (str): 負面提示詞。
-        seed (int): 隨機種子,-1 表示隨機。
-        output_filename (str): 輸出影片的檔案名稱(不含副檔名)。
-        ... 其他生成參數
-    """
+    """執行推理以生成影片。"""
     if seed < 0:
         seed = random.randint(0, np.iinfo(np.int32).max)
     print(f"使用的種子: {seed}")
 
-    # --- 記憶體優化設定 ---
     if gpu_memory_mode == "sequential_cpu_offload":
-        replace_parameters_by_name(transformer3d, ["modulation", ], device=device)
-        transformer3d.freqs = transformer3d.freqs.to(device=device)
         pipeline.enable_sequential_cpu_offload(device=device)
-    elif gpu_memory_mode == "model_cpu_offload_and_qfloat8":
-        convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation", ])
-        convert_weight_dtype_wrapper(transformer3d, dtype)
-        pipeline.enable_model_cpu_offload(device=device)
     elif gpu_memory_mode == "model_cpu_offload":
         pipeline.enable_model_cpu_offload(device=device)
     else:
         pipeline.to(device=device)
 
-    # --- TeaCache 加速 ---
-    if teacache_threshold > 0:
-        coefficients = get_teacache_coefficients(pipeline.transformer.config._name_or_path)
-        pipeline.transformer.enable_teacache(
-            coefficients,
-            num_inference_steps,
-            teacache_threshold,
-            num_skip_start_steps=num_skip_start_steps,
-        )
-
-    # --- 開始推理 ---
     with torch.no_grad():
         print("正在準備輸入資料...")
-        video_length = clip_sample_n_frames
+        # 由於 get_image_to_video_latent 內部有自己的 vae.config 引用,所以此處警告可忽略
+        video_length = 81
         input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
 
         sr = 16000
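FIX 1 is the substantive change behind the commit: the base Wan2.1 transformer is now overlaid with the StableAvatar checkpoint before the pipeline is built, and the checkpoint loader is written defensively. The pattern in isolation, with model and path as placeholders:

```python
import torch

def load_finetuned_weights(model, ckpt_path):
    # Checkpoints may be saved as {"state_dict": {...}} or as a bare mapping;
    # unwrap before loading.
    state = torch.load(ckpt_path, map_location="cpu")
    state = state["state_dict"] if "state_dict" in state else state
    # strict=False reports mismatched keys instead of aborting; the counts
    # show how much of the checkpoint actually landed in the module.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(f"missing={len(missing)}, unexpected={len(unexpected)}")
    return model
```

Loading with `map_location="cpu"` keeps the checkpoint off the GPU during setup, which matches the CPU-offload memory modes this app exposes.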
@@ -205,27 +164,13 @@ def run_inference(
 
         print("Pipeline 執行中... 這可能需要一些時間。")
         sample = pipeline(
-            prompt,
-            num_frames=video_length,
-            negative_prompt=negative_prompt,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
-            generator=torch.Generator().manual_seed(seed),
-            num_inference_steps=num_inference_steps,
-            video=input_video,
-            mask_video=input_video_mask,
-            clip_image=clip_image,
-            text_guide_scale=text_guide_scale,
-            audio_guide_scale=audio_guide_scale,
-            vocal_input_values=vocal_input,
-            motion_frame=motion_frame,
-            fps=fps,
-            sr=sr,
-            cond_file_path=image_path,
-            overlap_window_length=overlap_window_length,
-            seed=seed,
-            overlapping_weight_scheme=overlapping_weight_scheme,
+            prompt, num_frames=video_length, negative_prompt=negative_prompt,
+            width=width, height=height, guidance_scale=6.0,
+            generator=torch.Generator().manual_seed(seed), num_inference_steps=num_inference_steps,
+            video=input_video, mask_video=input_video_mask, clip_image=clip_image,
+            text_guide_scale=3.0, audio_guide_scale=5.0, vocal_input_values=vocal_input,
+            motion_frame=25, fps=fps, sr=sr, cond_file_path=image_path,
+            overlap_window_length=10, seed=seed, overlapping_weight_scheme="uniform",
         ).videos
 
         print("正在儲存影片...")
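The rewritten call pins randomness two ways: a CPU `torch.Generator` seeded with `seed`, plus the integer `seed` forwarded as a kwarg (presumably consumed elsewhere in the pipeline; its body is not in this diff). The generator half of that is the standard reproducibility pattern:

```python
import torch

# Two fresh generators seeded identically produce identical draws,
# so the initial latents are reproducible run to run.
a = torch.randn(2, 3, generator=torch.Generator().manual_seed(42))
b = torch.randn(2, 3, generator=torch.Generator().manual_seed(42))
assert torch.equal(a, b)
```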
@@ -242,7 +187,6 @@ def run_inference(
         output_video_with_audio
     ], check=True)
 
-    # 刪除無音訊的暫存影片
    os.remove(video_path)
 
     print(f"✅ 生成完成!影片已儲存至: {output_video_with_audio}")
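Only the tail of the ffmpeg mux call is visible in this hunk; the deletion just drops a comment restating what `os.remove` does. For orientation, a typical subprocess mux of the silent render with the driving audio looks like the sketch below; the exact flags app.py uses are not shown here, so these are assumptions:

```python
import os
import subprocess

video_path = "temp/output_silent.mp4"   # placeholder paths for the sketch
audio_path = "temp/audio.wav"
output_video_with_audio = "output.mp4"

subprocess.run([
    "ffmpeg", "-y",
    "-i", video_path,                   # silent video from the pipeline
    "-i", audio_path,                   # driving audio
    "-c:v", "copy", "-c:a", "aac", "-shortest",
    output_video_with_audio,
], check=True)
os.remove(video_path)                   # drop the silent intermediate
```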
@@ -250,92 +194,69 @@ def run_inference(
 
 def main():
     parser = argparse.ArgumentParser(description="StableAvatar 命令列推理工具")
-
-    # --- 主要參數 ---
     parser.add_argument('--prompt', type=str, default="a beautiful woman is talking, masterpiece, best quality", help='正面提示詞')
-    parser.add_argument('--input_image', type=str, default="
-    parser.add_argument('--input_audio', type=str, default="
+    parser.add_argument('--input_image', type=str, default="example_case/case-6/reference.png", help='輸入圖片的路徑或 URL')
+    parser.add_argument('--input_audio', type=str, default="example_case/case-6/audio.wav", help='輸入音訊的路徑或 URL')
     parser.add_argument('--seed', type=int, default=42, help='隨機種子,-1 表示隨機')
-
-    # --- 生成參數 ---
     parser.add_argument('--negative_prompt', type=str, default="vivid color, static, blur details, text, style, painting, picture, still, gray, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, malformed, deformed, bad anatomy, fused fingers, still image, messy background, many people in the background, walking backwards", help='負面提示詞')
     parser.add_argument('--width', type=int, default=512, help='影片寬度')
     parser.add_argument('--height', type=int, default=512, help='影片高度')
     parser.add_argument('--num_inference_steps', type=int, default=50, help='推理步數')
     parser.add_argument('--fps', type=int, default=25, help='影片幀率')
-    …
-    parser.add_argument('--repo_id', type=str, default="FrancisRing/StableAvatar", help='Hugging Face 模型的 Repo ID')
-    parser.add_argument('--gpu_memory_mode', type=str, default="model_cpu_offload", choices=["Normal", "model_cpu_offload", "model_cpu_offloadand_qfloat8", "sequential_cpu_offload"], help='GPU 記憶體優化模式')
-
+    parser.add_argument('--gpu_memory_mode', type=str, default="model_cpu_offload", choices=["Normal", "model_cpu_offload"], help='GPU 記憶體優化模式')
+    parser.add_argument('--model_version', type=str, default="square", choices=["square", "rec_vec"], help='StableAvatar 模型版本')
     args = parser.parse_args()
 
-    # --- 1. 下載模型 ---
     print("--- 步驟 1: 正在檢查並下載模型 ---")
-    REPO_ID = args.repo_id
     repo_root = snapshot_download(
-        repo_id=REPO_ID,
-        allow_patterns=[
-            "StableAvatar-1.3B/*",
-            "Wan2.1-Fun-V1.1-1.3B-InP/*",
-            "wav2vec2-base-960h/*",
-            "assets/**",
-            "Kim_Vocal_2.onnx",
-            "example_case/**", # 確保範例檔案被下載
-            "deepspeed_config/**",
-        ],
+        repo_id="FrancisRing/StableAvatar",
+        allow_patterns=["StableAvatar-1.3B/*", "Wan2.1-Fun-V1.1-1.3B-InP/*", "wav2vec2-base-960h/*", "example_case/**", "deepspeed_config/**"],
     )
     print("模型檔案已準備就緒。")
 
-    # --- 2. 處理輸入檔案 ---
     print("\n--- 步驟 2: 正在處理輸入檔案 ---")
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-
-    # 建立暫存目錄
     temp_dir = f"temp_{timestamp}"
     os.makedirs(temp_dir, exist_ok=True)
 
-    #
-    …
+    # <<< FIX 2: 穩健的路徑處理 >>>
+    # 處理圖片路徑
+    input_image_path = args.input_image
+    # 如果不是 URL 且不是絕對路徑,就視為相對於 repo_root 的路徑
+    if not input_image_path.startswith(('http', '/')):
+        input_image_path = os.path.join(repo_root, input_image_path)
+
+    local_image_path = os.path.join(temp_dir, os.path.basename(input_image_path))
+    final_image_path = download_file(input_image_path, local_image_path)
     if not final_image_path:
-        shutil.rmtree(temp_dir)
-        return
+        shutil.rmtree(temp_dir); return
 
-    #
-    …
+    # 處理音訊路徑
+    input_audio_path = args.input_audio
+    if not input_audio_path.startswith(('http', '/')):
+        input_audio_path = os.path.join(repo_root, input_audio_path)
+
+    local_audio_path = os.path.join(temp_dir, os.path.basename(input_audio_path))
+    final_audio_path = download_file(input_audio_path, local_audio_path)
     if not final_audio_path:
-        shutil.rmtree(temp_dir)
-        return
+        shutil.rmtree(temp_dir); return
+    # <<< END OF FIX 2 >>>
 
-    …
-    pipeline, transformer3d, vae = setup_models(repo_root)
+    print("\n--- 步驟 3: 正在載入模型 ---")
+    pipeline, transformer3d, vae = setup_models(repo_root, args.model_version)
     print("模型載入完成。")
 
-    # --- 4. 執行推理 ---
     print("\n--- 步驟 4: 開始執行推理 ---")
     run_inference(
-        pipeline=pipeline,
-        transformer3d=transformer3d,
-        vae=vae,
-        image_path=final_image_path,
-        audio_path=final_audio_path,
-        prompt=args.prompt,
-        negative_prompt=args.negative_prompt,
-        seed=args.seed,
-        output_filename=f"output_{timestamp}",
-        gpu_memory_mode=args.gpu_memory_mode,
-        width=args.width,
-        height=args.height,
-        num_inference_steps=args.num_inference_steps,
-        fps=args.fps,
+        pipeline=pipeline, transformer3d=transformer3d, vae=vae,
+        image_path=final_image_path, audio_path=final_audio_path,
+        prompt=args.prompt, negative_prompt=args.negative_prompt,
+        seed=args.seed, output_filename=f"output_{timestamp}",
+        gpu_memory_mode=args.gpu_memory_mode, width=args.width,
+        height=args.height, num_inference_steps=args.num_inference_steps,
+        fps=args.fps
    )
 
-    # --- 5. 清理 ---
     print("\n--- 步驟 5: 清理暫存檔案 ---")
     try:
         shutil.rmtree(temp_dir)
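FIX 2 reduces to one resolution rule applied to both inputs: anything that is neither a URL nor an absolute path is rebased onto the downloaded snapshot, which is what makes the new `example_case/...` defaults work out of the box. As a standalone sketch (POSIX-style paths assumed, which is what a Space runs on):

```python
import os

def resolve_input(path_or_url: str, repo_root: str) -> str:
    # URLs and absolute paths pass through; bare relative paths are taken
    # relative to the snapshot root so bundled example assets resolve.
    if not path_or_url.startswith(('http', '/')):
        return os.path.join(repo_root, path_or_url)
    return path_or_url

# resolve_input("example_case/case-6/reference.png", "/data/snapshot")
#   -> "/data/snapshot/example_case/case-6/reference.png"
```

`startswith('http')` is a loose test (a relative path that happens to begin with "http" would slip through), but it is cheap and adequate for this app's inputs.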
@@ -345,3 +266,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+
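With these changes the Space runs end to end from the command line: a bare `python app.py` uses the bundled `example_case/case-6` assets via the new defaults, while something like `python app.py --model_version rec_vec --seed -1 --gpu_memory_mode Normal` selects the rec-vec checkpoint, a random seed, and no CPU offloading.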