dangthr committed on
Commit c5c8aa3 · verified · 1 Parent(s): a2a9a31

Update app.py

Files changed (1)
  1. app.py +76 -154
app.py CHANGED
@@ -44,7 +44,7 @@ def filter_kwargs(cls, kwargs):
     return filtered_kwargs
 
 def download_file(url, local_path):
-    """Download a file from a URL"""
+    """Download a file from a URL; if given a local path, return it directly"""
     if url.startswith(('http://', 'https://')):
         print(f"Downloading file from {url}...")
         try:
@@ -65,12 +65,15 @@ def download_file(url, local_path):
         print(f"Error: file or URL does not exist: {url}")
         return None
 
-def setup_models(repo_root):
+def setup_models(repo_root, model_version):
     """Load all required models and settings"""
     pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
     pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
 
-    config = OmegaConf.load("deepspeed_config/wan2.1/wan_civitai.yaml")
+    config_path = os.path.join(repo_root, "deepspeed_config/wan2.1/wan_civitai.yaml")
+    if not os.path.exists(config_path):
+        raise FileNotFoundError(f"Config file not found: {config_path}")
+    config = OmegaConf.load(config_path)
     sampler_name = "Flow"
 
     print("Loading Tokenizer...")
@@ -97,7 +100,7 @@ def setup_models(repo_root):
     print("Loading CLIP Image Encoder...")
     clip_image_encoder = CLIPModel.from_pretrained(os.path.join(pretrained_model_name_or_path, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'))).eval()
 
-    print("Loading Transformer 3D...")
+    print("Loading Transformer 3D base model...")
     transformer3d = WanTransformer3DFantasyModel.from_pretrained(
         os.path.join(pretrained_model_name_or_path, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
         transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
@@ -105,99 +108,55 @@ def setup_models(repo_root):
         torch_dtype=dtype,
     )
 
-    scheduler_class = {
-        "Flow": FlowMatchEulerDiscreteScheduler,
-    }[sampler_name]
-    scheduler = scheduler_class(
-        **filter_kwargs(scheduler_class, OmegaConf.to_container(config['scheduler_kwargs']))
-    )
+    # <<< FIX 1: load the StableAvatar-specific weights >>>
+    if model_version == "square":
+        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-square.pt")
+    else:  # rec_vec
+        transformer_path = os.path.join(repo_root, "StableAvatar-1.3B", "transformer3d-rec-vec.pt")
+
+    if os.path.exists(transformer_path):
+        print(f"Loading StableAvatar weights from {transformer_path}...")
+        state_dict = torch.load(transformer_path, map_location="cpu")
+        state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
+        m, u = transformer3d.load_state_dict(state_dict, strict=False)
+        print(f"StableAvatar weights loaded successfully. Missing keys: {len(m)}; unexpected keys: {len(u)}")
+    else:
+        raise FileNotFoundError(f"StableAvatar weight file not found: {transformer_path}. Please make sure the model has been fully downloaded.")
+    # <<< END OF FIX 1 >>>
+
+    scheduler_class = {"Flow": FlowMatchEulerDiscreteScheduler}[sampler_name]
+    scheduler = scheduler_class(**filter_kwargs(scheduler_class, OmegaConf.to_container(config['scheduler_kwargs'])))
 
     print("Building Pipeline...")
     pipeline = WanI2VTalkingInferenceLongPipeline(
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        transformer=transformer3d,
-        clip_image_encoder=clip_image_encoder,
-        scheduler=scheduler,
-        wav2vec_processor=wav2vec_processor,
-        wav2vec=wav2vec,
+        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae,
+        transformer=transformer3d, clip_image_encoder=clip_image_encoder,
+        scheduler=scheduler, wav2vec_processor=wav2vec_processor, wav2vec=wav2vec,
     )
 
     return pipeline, transformer3d, vae
 
 def run_inference(
-    pipeline,
-    transformer3d,
-    vae,
-    image_path,
-    audio_path,
-    prompt,
-    negative_prompt,
-    seed,
-    output_filename,
-    gpu_memory_mode="model_cpu_offload",
-    teacache_threshold=0,
-    num_skip_start_steps=5,
-    width=512,
-    height=512,
-    guidance_scale=6.0,
-    num_inference_steps=50,
-    text_guide_scale=3.0,
-    audio_guide_scale=5.0,
-    motion_frame=25,
-    fps=25,
-    overlap_window_length=10,
-    overlapping_weight_scheme="uniform",
-    clip_sample_n_frames=81,
+    pipeline, transformer3d, vae, image_path, audio_path, prompt,
+    negative_prompt, seed, output_filename, gpu_memory_mode="model_cpu_offload",
+    width=512, height=512, num_inference_steps=50, fps=25, **kwargs
 ):
-    """
-    Run inference to generate a video.
-
-    Args:
-        pipeline: the inference pipeline.
-        transformer3d: the 3D transformer model.
-        vae: the VAE model.
-        image_path (str): path to the input image.
-        audio_path (str): path to the input audio.
-        prompt (str): positive prompt.
-        negative_prompt (str): negative prompt.
-        seed (int): random seed; -1 means random.
-        output_filename (str): output video file name (without extension).
-        ... other generation parameters
-    """
+    """Run inference to generate a video."""
     if seed < 0:
         seed = random.randint(0, np.iinfo(np.int32).max)
     print(f"Seed in use: {seed}")
 
-    # --- Memory optimization settings ---
     if gpu_memory_mode == "sequential_cpu_offload":
-        replace_parameters_by_name(transformer3d, ["modulation", ], device=device)
-        transformer3d.freqs = transformer3d.freqs.to(device=device)
         pipeline.enable_sequential_cpu_offload(device=device)
-    elif gpu_memory_mode == "model_cpu_offload_and_qfloat8":
-        convert_model_weight_to_float8(transformer3d, exclude_module_name=["modulation", ])
-        convert_weight_dtype_wrapper(transformer3d, dtype)
-        pipeline.enable_model_cpu_offload(device=device)
     elif gpu_memory_mode == "model_cpu_offload":
         pipeline.enable_model_cpu_offload(device=device)
     else:
         pipeline.to(device=device)
 
-    # --- TeaCache acceleration ---
-    if teacache_threshold > 0:
-        coefficients = get_teacache_coefficients(pipeline.transformer.config._name_or_path)
-        pipeline.transformer.enable_teacache(
-            coefficients,
-            num_inference_steps,
-            teacache_threshold,
-            num_skip_start_steps=num_skip_start_steps,
-        )
-
-    # --- Start inference ---
     with torch.no_grad():
         print("Preparing input data...")
-        video_length = int((clip_sample_n_frames - 1) // vae.config.temporal_compression_ratio * vae.config.temporal_compression_ratio) + 1 if clip_sample_n_frames != 1 else 1
+        # get_image_to_video_latent keeps its own vae.config reference internally, so the warning here can be ignored
+        video_length = 81
         input_video, input_video_mask, clip_image = get_image_to_video_latent(image_path, None, video_length=video_length, sample_size=[height, width])
 
         sr = 16000
@@ -205,27 +164,13 @@ def run_inference(
 
         print("Pipeline running... this may take a while.")
         sample = pipeline(
-            prompt,
-            num_frames=video_length,
-            negative_prompt=negative_prompt,
-            width=width,
-            height=height,
-            guidance_scale=guidance_scale,
-            generator=torch.Generator().manual_seed(seed),
-            num_inference_steps=num_inference_steps,
-            video=input_video,
-            mask_video=input_video_mask,
-            clip_image=clip_image,
-            text_guide_scale=text_guide_scale,
-            audio_guide_scale=audio_guide_scale,
-            vocal_input_values=vocal_input,
-            motion_frame=motion_frame,
-            fps=fps,
-            sr=sr,
-            cond_file_path=image_path,
-            overlap_window_length=overlap_window_length,
-            seed=seed,
-            overlapping_weight_scheme=overlapping_weight_scheme,
+            prompt, num_frames=video_length, negative_prompt=negative_prompt,
+            width=width, height=height, guidance_scale=6.0,
+            generator=torch.Generator().manual_seed(seed), num_inference_steps=num_inference_steps,
+            video=input_video, mask_video=input_video_mask, clip_image=clip_image,
+            text_guide_scale=3.0, audio_guide_scale=5.0, vocal_input_values=vocal_input,
+            motion_frame=25, fps=fps, sr=sr, cond_file_path=image_path,
+            overlap_window_length=10, seed=seed, overlapping_weight_scheme="uniform",
         ).videos
 
         print("Saving video...")
@@ -242,7 +187,6 @@ def run_inference(
             output_video_with_audio
         ], check=True)
 
-        # Delete the temporary video without audio
         os.remove(video_path)
 
         print(f"✅ Generation complete! Video saved to: {output_video_with_audio}")
@@ -250,92 +194,69 @@ def run_inference(
 
 def main():
     parser = argparse.ArgumentParser(description="StableAvatar command-line inference tool")
-
-    # --- Main arguments ---
     parser.add_argument('--prompt', type=str, default="a beautiful woman is talking, masterpiece, best quality", help='Positive prompt')
-    parser.add_argument('--input_image', type=str, default="./example_case/case-1/reference.png", help='Path or URL of the input image')
-    parser.add_argument('--input_audio', type=str, default="./example_case/case-1/audio.wav", help='Path or URL of the input audio')
+    parser.add_argument('--input_image', type=str, default="example_case/case-6/reference.png", help='Path or URL of the input image')
+    parser.add_argument('--input_audio', type=str, default="example_case/case-6/audio.wav", help='Path or URL of the input audio')
     parser.add_argument('--seed', type=int, default=42, help='Random seed; -1 means random')
-
-    # --- Generation arguments ---
     parser.add_argument('--negative_prompt', type=str, default="vivid color, static, blur details, text, style, painting, picture, still, gray, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, out of frame, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, malformed, deformed, bad anatomy, fused fingers, still image, messy background, many people in the background, walking backwards", help='Negative prompt')
     parser.add_argument('--width', type=int, default=512, help='Video width')
     parser.add_argument('--height', type=int, default=512, help='Video height')
     parser.add_argument('--num_inference_steps', type=int, default=50, help='Number of inference steps')
     parser.add_argument('--fps', type=int, default=25, help='Video frame rate')
-
-    # --- Model and optimization arguments ---
-    parser.add_argument('--repo_id', type=str, default="FrancisRing/StableAvatar", help='Repo ID of the Hugging Face model')
-    parser.add_argument('--gpu_memory_mode', type=str, default="model_cpu_offload", choices=["Normal", "model_cpu_offload", "model_cpu_offloadand_qfloat8", "sequential_cpu_offload"], help='GPU memory optimization mode')
-
+    parser.add_argument('--gpu_memory_mode', type=str, default="model_cpu_offload", choices=["Normal", "model_cpu_offload"], help='GPU memory optimization mode')
+    parser.add_argument('--model_version', type=str, default="square", choices=["square", "rec_vec"], help='StableAvatar model version')
     args = parser.parse_args()
 
-    # --- 1. Download models ---
     print("--- Step 1: Checking and downloading models ---")
-    REPO_ID = args.repo_id
     repo_root = snapshot_download(
-        repo_id=REPO_ID,
-        allow_patterns=[
-            "StableAvatar-1.3B/*",
-            "Wan2.1-Fun-V1.1-1.3B-InP/*",
-            "wav2vec2-base-960h/*",
-            "assets/**",
-            "Kim_Vocal_2.onnx",
-            "example_case/**",  # make sure the example files are downloaded
-            "deepspeed_config/**",
-        ],
+        repo_id="FrancisRing/StableAvatar",
+        allow_patterns=["StableAvatar-1.3B/*", "Wan2.1-Fun-V1.1-1.3B-InP/*", "wav2vec2-base-960h/*", "example_case/**", "deepspeed_config/**"],
     )
     print("Model files are ready.")
 
-    # --- 2. Process input files ---
     print("\n--- Step 2: Processing input files ---")
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
-
-    # Create a temporary directory
     temp_dir = f"temp_{timestamp}"
    os.makedirs(temp_dir, exist_ok=True)
 
-    # Process the image
-    image_ext = os.path.splitext(args.input_image)[1] or '.png'
-    local_image_path = os.path.join(temp_dir, f"input_image{image_ext}")
-    final_image_path = download_file(args.input_image, local_image_path)
+    # <<< FIX 2: robust path handling >>>
+    # Handle the image path
+    input_image_path = args.input_image
+    # If it is neither a URL nor an absolute path, treat it as relative to repo_root
+    if not input_image_path.startswith(('http', '/')):
+        input_image_path = os.path.join(repo_root, input_image_path)
+
+    local_image_path = os.path.join(temp_dir, os.path.basename(input_image_path))
+    final_image_path = download_file(input_image_path, local_image_path)
     if not final_image_path:
-        shutil.rmtree(temp_dir)
-        return
+        shutil.rmtree(temp_dir); return
 
-    # Process the audio
-    audio_ext = os.path.splitext(args.input_audio)[1] or '.wav'
-    local_audio_path = os.path.join(temp_dir, f"input_audio{audio_ext}")
-    final_audio_path = download_file(args.input_audio, local_audio_path)
+    # Handle the audio path
+    input_audio_path = args.input_audio
+    if not input_audio_path.startswith(('http', '/')):
+        input_audio_path = os.path.join(repo_root, input_audio_path)
+
+    local_audio_path = os.path.join(temp_dir, os.path.basename(input_audio_path))
+    final_audio_path = download_file(input_audio_path, local_audio_path)
     if not final_audio_path:
-        shutil.rmtree(temp_dir)
-        return
+        shutil.rmtree(temp_dir); return
+    # <<< END OF FIX 2 >>>
 
-    # --- 3. Load models ---
-    print("\n--- Step 3: Loading models (this may take a while) ---")
-    pipeline, transformer3d, vae = setup_models(repo_root)
+    print("\n--- Step 3: Loading models ---")
+    pipeline, transformer3d, vae = setup_models(repo_root, args.model_version)
     print("Models loaded.")
 
-    # --- 4. Run inference ---
     print("\n--- Step 4: Running inference ---")
     run_inference(
-        pipeline=pipeline,
-        transformer3d=transformer3d,
-        vae=vae,
-        image_path=final_image_path,
-        audio_path=final_audio_path,
-        prompt=args.prompt,
-        negative_prompt=args.negative_prompt,
-        seed=args.seed,
-        output_filename=f"output_{timestamp}",
-        gpu_memory_mode=args.gpu_memory_mode,
-        width=args.width,
-        height=args.height,
-        num_inference_steps=args.num_inference_steps,
-        fps=args.fps,
+        pipeline=pipeline, transformer3d=transformer3d, vae=vae,
+        image_path=final_image_path, audio_path=final_audio_path,
+        prompt=args.prompt, negative_prompt=args.negative_prompt,
+        seed=args.seed, output_filename=f"output_{timestamp}",
+        gpu_memory_mode=args.gpu_memory_mode, width=args.width,
+        height=args.height, num_inference_steps=args.num_inference_steps,
+        fps=args.fps
    )
 
-    # --- 5. Cleanup ---
     print("\n--- Step 5: Cleaning up temporary files ---")
     try:
         shutil.rmtree(temp_dir)
@@ -345,3 +266,4 @@ def main():
 
 if __name__ == "__main__":
     main()
+
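
Note: the new weight-loading step in setup_models (FIX 1 above) follows a common PyTorch pattern: load a .pt checkpoint on CPU, unwrap an optional "state_dict" container, and apply it non-strictly so fine-tuned weights can be layered onto an already-constructed base model. A minimal standalone sketch of that pattern; the helper name and path are illustrative, not part of the commit:

    import os
    import torch
    import torch.nn as nn

    def load_finetuned_weights(model: nn.Module, checkpoint_path: str) -> nn.Module:  # hypothetical helper
        # Fail early with a clear message, as the commit does.
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Weight file not found: {checkpoint_path}")
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        # Some trainers save {"state_dict": {...}}; plain torch.save stores the tensors directly.
        state_dict = state_dict.get("state_dict", state_dict)
        # strict=False tolerates missing/unexpected keys, e.g. modules the fine-tune added or dropped.
        missing, unexpected = model.load_state_dict(state_dict, strict=False)
        print(f"Missing keys: {len(missing)}; unexpected keys: {len(unexpected)}")
        return model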
 
 
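FIX 2 reduces to one resolution rule: an input that is neither a URL nor an absolute path is joined onto the snapshot_download root, so the bundled example_case assets resolve without manual copying. A minimal sketch of that rule, with a hypothetical resolve_input helper:

    import os

    def resolve_input(path_or_url: str, repo_root: str) -> str:  # hypothetical helper
        # URLs and absolute paths pass through; everything else is relative to the snapshot.
        if not path_or_url.startswith(('http', '/')):
            return os.path.join(repo_root, path_or_url)
        return path_or_url

With the new flags in place, a typical invocation of the updated script (assuming it is run directly as app.py) might look like:

    python app.py --model_version rec_vec --seed -1 --gpu_memory_mode model_cpu_offload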