dangthr commited on
Commit
a713267
·
verified ·
1 Parent(s): d3adf4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -9
app.py CHANGED
@@ -46,16 +46,13 @@ def filter_kwargs(cls, kwargs):
46
  def resolve_path(user_path, repo_root):
47
  """
48
  以正確的優先級解析檔案路徑。
49
- 1. 檢查是否為絕對路徑或本地存在的相對路徑。
50
- 2. 如果找不到,則嘗試將其視為相對於 Hugging Face 快取根目錄的路徑。
51
- 3. 如果都不是,則返回 None。
52
  """
53
- # 優先檢查本地路徑(絕對或相對)
54
  if os.path.exists(user_path):
55
  print(f"找到本地檔案: {os.path.abspath(user_path)}")
56
  return os.path.abspath(user_path)
57
 
58
- # 如果本地找不到,再嘗試從 HF 快取目錄中尋找
59
  potential_repo_path = os.path.join(repo_root, user_path)
60
  if os.path.exists(potential_repo_path):
61
  print(f"在 Hugging Face 快取目錄中找到檔案: {potential_repo_path}")
@@ -69,9 +66,10 @@ def setup_models(repo_root, model_version):
69
  pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
70
  pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
71
 
 
72
  config_path = os.path.join(repo_root, "deepspeed_config/wan2.1/wan_civitai.yaml")
73
  if not os.path.exists(config_path):
74
- raise FileNotFoundError(f"設定檔未找到: {config_path}")
75
  config = OmegaConf.load(config_path)
76
  sampler_name = "Flow"
77
 
@@ -203,12 +201,18 @@ def main():
203
  parser.add_argument('--model_version', type=str, default="square", choices=["square", "rec_vec"], help='StableAvatar 模型版本')
204
  args = parser.parse_args()
205
 
206
- print("--- 步驟 1: 正在檢查並下載模型 ---")
207
- # 這一步仍然是必要的,以確保模型權重檔存在於快取中
208
  repo_root = snapshot_download(
209
  repo_id="FrancisRing/StableAvatar",
210
- allow_patterns=["StableAvatar-1.3B/*", "Wan2.1-Fun-V1.1-1.3B-InP/*", "wav2vec2-base-960h/*", "deepspeed_config/**"],
 
 
 
 
 
211
  )
 
212
  print("模型檔案已準備就緒。")
213
 
214
  print("\n--- 步驟 2: 正在解析輸入檔案路徑 ---")
@@ -223,6 +227,7 @@ def main():
223
  return
224
 
225
  print("\n--- 步驟 3: 正在載入模型 ---")
 
226
  pipeline, transformer3d, vae = setup_models(repo_root, args.model_version)
227
  print("模型載入完成。")
228
 
 
46
  def resolve_path(user_path, repo_root):
47
  """
48
  以正確的優先級解析檔案路徑。
49
+ 1. 優先檢查本地路徑(絕對或相對)。
50
+ 2. 如果找不到,則嘗試從 HF 快取目錄中尋找。
 
51
  """
 
52
  if os.path.exists(user_path):
53
  print(f"找到本地檔案: {os.path.abspath(user_path)}")
54
  return os.path.abspath(user_path)
55
 
 
56
  potential_repo_path = os.path.join(repo_root, user_path)
57
  if os.path.exists(potential_repo_path):
58
  print(f"在 Hugging Face 快取目錄中找到檔案: {potential_repo_path}")
 
66
  pretrained_model_name_or_path = os.path.join(repo_root, "Wan2.1-Fun-V1.1-1.3B-InP")
67
  pretrained_wav2vec_path = os.path.join(repo_root, "wav2vec2-base-960h")
68
 
69
+ # 這個路徑現在應該可以正確找到了
70
  config_path = os.path.join(repo_root, "deepspeed_config/wan2.1/wan_civitai.yaml")
71
  if not os.path.exists(config_path):
72
+ raise FileNotFoundError(f"設定檔未找到: {config_path},請檢查 snapshot_download 是否已下載此檔案。")
73
  config = OmegaConf.load(config_path)
74
  sampler_name = "Flow"
75
 
 
201
  parser.add_argument('--model_version', type=str, default="square", choices=["square", "rec_vec"], help='StableAvatar 模型版本')
202
  args = parser.parse_args()
203
 
204
+ print("--- 步驟 1: 正在檢查並下載模型與設定檔 ---")
205
+ # <<< 核心修正:加入 'deepspeed_config/**' 來下載設定檔 >>>
206
  repo_root = snapshot_download(
207
  repo_id="FrancisRing/StableAvatar",
208
+ allow_patterns=[
209
+ "StableAvatar-1.3B/*",
210
+ "Wan2.1-Fun-V1.1-1.3B-InP/*",
211
+ "wav2vec2-base-960h/*",
212
+ "deepspeed_config/**" # <-- 修正點
213
+ ],
214
  )
215
+ # <<< 修正結束 >>>
216
  print("模型檔案已準備就緒。")
217
 
218
  print("\n--- 步驟 2: 正在解析輸入檔案路徑 ---")
 
227
  return
228
 
229
  print("\n--- 步驟 3: 正在載入模型 ---")
230
+ # 將 repo_root 傳遞給 setup_models,這樣它才能在正確的位置找到設定檔
231
  pipeline, transformer3d, vae = setup_models(repo_root, args.model_version)
232
  print("模型載入完成。")
233