# uvr5/download.py
# Author: lorneluo — commit 14a97f2 ("clean UVR import").
# Command-line helper that downloads UVR5 model files into ./models.
import json
import os
import sys
import urllib
import urllib.request
from pprint import pprint

import wget
from tqdm import tqdm

from gui_data.constants import DOWNLOAD_CHECKS, NORMAL_REPO, UPDATE_REPO
# Local directories where downloaded model files are stored, relative to this
# script's own location.
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_PATH, 'models')
VR_MODELS_DIR = os.path.join(MODELS_DIR, 'VR_Models')
MDX_MODELS_DIR = os.path.join(MODELS_DIR, 'MDX_Net_Models')
DEMUCS_MODELS_DIR = os.path.join(MODELS_DIR, 'Demucs_Models')
DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo')

# The model catalogue is fetched over the network once, at import time, so
# importing this module requires access to DOWNLOAD_CHECKS. The response is
# used as a context manager so its connection is closed after parsing
# (previously it was left open).
with urllib.request.urlopen(DOWNLOAD_CHECKS) as _catalogue_response:
    online_model_data = json.load(_catalogue_response)

# MDX-Net, MDX23C and MDX23 entries share one lookup table. A value is either
# a filename or a dict whose first key is the filename (handled in
# get_mdx_model_filelist). The VIP lists are excluded (commented out).
mdx_download_list = {
    **online_model_data["mdx_download_list"],
    **online_model_data["mdx23c_download_list"],
    **online_model_data["mdx23_download_list"],
    # **online_model_data["mdx_download_vip_list"],
    # **online_model_data["mdx23c_download_vip_list"],
}
# VR entries: model name -> filename on NORMAL_REPO.
vr_download_list = online_model_data["vr_download_list"]
# Demucs entries: model name -> {filename: download url}.
demucs_download_list = online_model_data["demucs_download_list"]
def get_mdx_model_file(model):
    """Return the local path of the MDX model file for catalogue entry ``model``.

    Raises KeyError if ``model`` is not a key of ``mdx_download_list``.
    """
    local_path, _url = get_mdx_model_filelist(model)[0]
    return local_path
def get_mdx_model_filelist(model):
    """Return ``[(local_path, url)]`` for the MDX catalogue entry ``model``.

    A catalogue value is either a plain filename or a dict. For a dict, only
    its first key is used as the filename; its values are ignored here
    (presumably an associated config file — verify against the catalogue).

    Raises KeyError if ``model`` is not a key of ``mdx_download_list``.
    """
    entry = mdx_download_list[model]
    file_name = list(entry)[0] if isinstance(entry, dict) else str(entry)
    download_url = f"{NORMAL_REPO}{file_name}"
    return [(os.path.join(MDX_MODELS_DIR, file_name), download_url)]
def get_vr_model_file(model):
    """Return the local path of the VR model file for catalogue entry ``model``.

    Raises KeyError if ``model`` is not a key of ``vr_download_list``.
    """
    local_path, _url = get_vr_model_filelist(model)[0]
    return local_path
def get_vr_model_filelist(model):
    """Return ``[(local_path, url)]`` for the VR catalogue entry ``model``.

    The catalogue value is the filename, which is appended to NORMAL_REPO to
    build the URL and joined onto VR_MODELS_DIR for the local path.
    """
    file_name = vr_download_list[model]
    download_url = f"{NORMAL_REPO}{file_name}"
    local_path = os.path.join(VR_MODELS_DIR, file_name)
    return [(local_path, download_url)]
def get_demucs_model_file(model):
    """Return the local path of the Demucs model's ``.yaml`` file.

    The first file whose name ends in ``.yaml`` (case-insensitive) wins.
    Returns None when the entry lists no such file.
    """
    yaml_paths = (
        local_path
        for local_path, _url in get_demucs_model_filelist(model)
        if local_path.lower().endswith('.yaml')
    )
    return next(yaml_paths, None)
def get_demucs_model_filelist(model):
    """Return ``[(local_path, url), ...]`` for every file of Demucs entry ``model``.

    Each catalogue entry maps a filename to its download URL. All files are
    placed in DEMUCS_NEWER_REPO_DIR, and the entry's iteration order is kept.
    """
    return [
        (os.path.join(DEMUCS_NEWER_REPO_DIR, file_name), file_url)
        for file_name, file_url in demucs_download_list[model].items()
    ]
def get_model_file(model_name):
    """Resolve ``model_name`` to the local path of its primary model file.

    Catalogues are checked in the order MDX, VR, Demucs, and the first one
    containing the name is used. For Demucs the result is the ``.yaml`` file
    path, or None if the entry has no such file.

    Raises:
        FileNotFoundError: if no catalogue lists ``model_name``.
    """
    resolvers = (
        (mdx_download_list, get_mdx_model_file),
        (vr_download_list, get_vr_model_file),
        (demucs_download_list, get_demucs_model_file),
    )
    for catalogue, resolve in resolvers:
        if model_name in catalogue:
            return resolve(model_name)
    raise FileNotFoundError(f"Can't found model {model_name}")
def download_model(model_name):
    """Download every file of ``model_name`` that is not already on disk.

    Catalogues are checked in the order MDX, VR, Demucs. Existing files are
    skipped one by one, so any missing files of a multi-file (Demucs) model
    are still fetched even when some of its files are already present.

    Raises:
        FileNotFoundError: if no catalogue lists ``model_name``.
    """
    if model_name in mdx_download_list:
        filelist = get_mdx_model_filelist(model_name)
    elif model_name in vr_download_list:
        filelist = get_vr_model_filelist(model_name)
    elif model_name in demucs_download_list:
        filelist = get_demucs_model_filelist(model_name)
    else:
        raise FileNotFoundError(f"Can't found model {model_name}")
    for model_path, url in filelist:
        if os.path.isfile(model_path):
            # Skip only this file; returning here would abandon the remaining
            # files of a multi-file model.
            continue
        # The models/ sub-directories may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        print(f'Downloading from {url} to {model_path}')
        wget.download(url, model_path)
if __name__ == '__main__':
    # Usage: python download.py <model_name>
    # Exit with a usage message instead of an IndexError traceback when the
    # model name is missing; any extra arguments are ignored, as before.
    if len(sys.argv) < 2:
        sys.exit(f'Usage: {sys.argv[0]} <model_name>')
    model_name = sys.argv[1]
    download_model(model_name)