lorneluo commited on
Commit
83c8e0b
1 Parent(s): 40258be

model downloader

Browse files
Files changed (1) hide show
  1. download.py +97 -0
download.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+ import urllib
5
+ from pprint import pprint
6
+
7
+ import wget
8
+ from tqdm import tqdm
9
+
10
+ from UVR import DEMUCS_NEWER_REPO_DIR, VR_MODELS_DIR, MDX_MODELS_DIR
11
+ from gui_data.constants import DOWNLOAD_CHECKS, NORMAL_REPO, UPDATE_REPO
12
+
13
+ online_model_data = json.load(urllib.request.urlopen(DOWNLOAD_CHECKS))
14
+ mdx_download_list = {
15
+ **online_model_data["mdx_download_list"],
16
+ **online_model_data["mdx23c_download_list"],
17
+ **online_model_data["mdx23_download_list"],
18
+ # **online_model_data["mdx_download_vip_list"],
19
+ # **online_model_data["mdx23c_download_vip_list"],
20
+ }
21
+ vr_download_list = online_model_data["vr_download_list"]
22
+ demucs_download_list = online_model_data["demucs_download_list"]
23
+
24
+
25
+ def get_mdx_model_file(model):
26
+ return get_mdx_model_filelist(model)[0][0]
27
+
28
+
29
+ def get_mdx_model_filelist(model):
30
+ filename = mdx_download_list[model]
31
+ if isinstance(filename, dict):
32
+ model_name = list(filename.keys())[0]
33
+ else:
34
+ model_name = str(filename)
35
+ model_path = os.path.join(MDX_MODELS_DIR, model_name)
36
+ url = f"{NORMAL_REPO}{model_name}"
37
+
38
+ return [(model_path, url)]
39
+
40
+
41
+ def get_vr_model_file(model):
42
+ return get_vr_model_filelist(model)[0][0]
43
+
44
+
45
+ def get_vr_model_filelist(model):
46
+ filename = vr_download_list[model]
47
+ url = f"{NORMAL_REPO}{filename}"
48
+ model_path = os.path.join(VR_MODELS_DIR, filename)
49
+ return [(model_path, url)]
50
+
51
+
52
+ def get_demucs_model_file(model):
53
+ for filename, url in get_demucs_model_filelist(model):
54
+ if filename.lower().endswith('.yaml'):
55
+ return filename
56
+
57
+
58
+ def get_demucs_model_filelist(model):
59
+ download_demucs_newer_models = []
60
+ for filename, url in demucs_download_list[model].items():
61
+ model_path = os.path.join(DEMUCS_NEWER_REPO_DIR, filename)
62
+ download_demucs_newer_models.append((model_path, url))
63
+ return download_demucs_newer_models
64
+
65
+
66
+ def get_model_file(model_name):
67
+ if model_name in mdx_download_list:
68
+ model_path = get_mdx_model_file(model_name)
69
+ elif model_name in vr_download_list:
70
+ model_path = get_vr_model_file(model_name)
71
+ elif model_name in demucs_download_list:
72
+ model_path = get_demucs_model_file(model_name)
73
+ else:
74
+ raise FileNotFoundError(f"Can't found model {model_name}")
75
+ return model_path
76
+
77
+
78
+ def download_model(model_name):
79
+ if model_name in mdx_download_list:
80
+ filelist = get_mdx_model_filelist(model_name)
81
+ elif model_name in vr_download_list:
82
+ filelist = get_vr_model_filelist(model_name)
83
+ elif model_name in demucs_download_list:
84
+ filelist = get_demucs_model_filelist(model_name)
85
+ else:
86
+ raise FileNotFoundError(f"Can't found model {model_name}")
87
+
88
+ for model_path, url in filelist:
89
+ if os.path.isfile(model_path):
90
+ return
91
+ print(f'Downloading from {url} to {model_path}')
92
+ wget.download(url, model_path)
93
+
94
+
95
+ if __name__ == '__main__':
96
+ model_name = sys.argv[1]
97
+ download_model(model_name)