Spaces:
Running
on
Zero
Running
on
Zero
fix import bug
Browse files
src/SongFormer/utils/fetch_pretrained.py
CHANGED
|
@@ -4,6 +4,7 @@ from tqdm import tqdm
|
|
| 4 |
|
| 5 |
|
| 6 |
def download(url, path):
|
|
|
|
| 7 |
if os.path.exists(path):
|
| 8 |
print(f"File already exists, skipping download: {path}")
|
| 9 |
return
|
|
@@ -25,16 +26,31 @@ def download(url, path):
|
|
| 25 |
bar.update(size)
|
| 26 |
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
"https://huggingface.co
|
| 35 |
-
os.path.join("ckpts", "MusicFM", "pretrained_msd.pt"),
|
| 36 |
-
)
|
| 37 |
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def download(url, path):
|
| 7 |
+
"""Download file from url to local path with progress bar."""
|
| 8 |
if os.path.exists(path):
|
| 9 |
print(f"File already exists, skipping download: {path}")
|
| 10 |
return
|
|
|
|
| 26 |
bar.update(size)
|
| 27 |
|
| 28 |
|
| 29 |
+
def download_all(use_mirror: bool = False):
|
| 30 |
+
"""Download all required checkpoints.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
use_mirror (bool): If True, use hf-mirror.com (for Mainland China).
|
| 34 |
+
"""
|
| 35 |
+
base_url = "https://hf-mirror.com" if use_mirror else "https://huggingface.co"
|
|
|
|
|
|
|
| 36 |
|
| 37 |
+
urls = [
|
| 38 |
+
(f"{base_url}/minzwon/MusicFM/resolve/main/msd_stats.json",
|
| 39 |
+
os.path.join("ckpts", "MusicFM", "msd_stats.json")),
|
| 40 |
+
(f"{base_url}/minzwon/MusicFM/resolve/main/pretrained_msd.pt",
|
| 41 |
+
os.path.join("ckpts", "MusicFM", "pretrained_msd.pt")),
|
| 42 |
+
(f"{base_url}/ASLP-lab/SongFormer/resolve/main/SongFormer.safetensors",
|
| 43 |
+
os.path.join("ckpts", "SongFormer.safetensors")),
|
| 44 |
+
|
| 45 |
+
# The content of safetensors is the same as pt, it is recommended to use safetensors
|
| 46 |
+
# (f"{base_url}/ASLP-lab/SongFormer/resolve/main/SongFormer.pt",
|
| 47 |
+
# os.path.join("ckpts", "SongFormer.pt")),
|
| 48 |
+
]
|
| 49 |
+
|
| 50 |
+
for url, path in urls:
|
| 51 |
+
download(url, path)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
if __name__ == "__main__":
|
| 55 |
+
# By default, use HuggingFace. If you are in Mainland China, change to True
|
| 56 |
+
download_all(use_mirror=False)
|