ASLP-lab commited on
Commit
952133e
·
1 Parent(s): 5a5fa7d

fix import bug

Browse files
src/SongFormer/utils/fetch_pretrained.py CHANGED
@@ -4,6 +4,7 @@ from tqdm import tqdm
4
 
5
 
6
  def download(url, path):
 
7
  if os.path.exists(path):
8
  print(f"File already exists, skipping download: {path}")
9
  return
@@ -25,16 +26,31 @@ def download(url, path):
25
  bar.update(size)
26
 
27
 
28
- # 根据 https://github.com/minzwon/musicfm 下载预训练模型
29
- download(
30
- "https://huggingface.co/minzwon/MusicFM/resolve/main/msd_stats.json",
31
- os.path.join("ckpts", "MusicFM", "msd_stats.json"),
32
- )
33
- download(
34
- "https://huggingface.co/minzwon/MusicFM/resolve/main/pretrained_msd.pt",
35
- os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"),
36
- )
37
 
38
- # for Mainland China
39
- # download('https://hf-mirror.com/minzwon/MusicFM/resolve/main/msd_stats.json', os.path.join("ckpts", "MusicFM", "msd_stats.json"))
40
- # download('https://hf-mirror.com/minzwon/MusicFM/resolve/main/pretrained_msd.pt', os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def download(url, path):
7
+ """Download file from url to local path with progress bar."""
8
  if os.path.exists(path):
9
  print(f"File already exists, skipping download: {path}")
10
  return
 
26
  bar.update(size)
27
 
28
 
29
+ def download_all(use_mirror: bool = False):
30
+ """Download all required checkpoints.
31
+
32
+ Args:
33
+ use_mirror (bool): If True, use hf-mirror.com (for Mainland China).
34
+ """
35
+ base_url = "https://hf-mirror.com" if use_mirror else "https://huggingface.co"
 
 
36
 
37
+ urls = [
38
+ (f"{base_url}/minzwon/MusicFM/resolve/main/msd_stats.json",
39
+ os.path.join("ckpts", "MusicFM", "msd_stats.json")),
40
+ (f"{base_url}/minzwon/MusicFM/resolve/main/pretrained_msd.pt",
41
+ os.path.join("ckpts", "MusicFM", "pretrained_msd.pt")),
42
+ (f"{base_url}/ASLP-lab/SongFormer/resolve/main/SongFormer.safetensors",
43
+ os.path.join("ckpts", "SongFormer.safetensors")),
44
+
45
+ # The content of safetensors is the same as pt, it is recommended to use safetensors
46
+ # (f"{base_url}/ASLP-lab/SongFormer/resolve/main/SongFormer.pt",
47
+ # os.path.join("ckpts", "SongFormer.pt")),
48
+ ]
49
+
50
+ for url, path in urls:
51
+ download(url, path)
52
+
53
+
54
+ if __name__ == "__main__":
55
+ # By default, use HuggingFace. If you are in Mainland China, change to True
56
+ download_all(use_mirror=False)