uvr5 / download.py
lorneluo's picture
clean UVR import
14a97f2
raw
history blame
3.27 kB
import json
import os
import sys
import urllib
import urllib.request
from pprint import pprint

import wget
from tqdm import tqdm

from gui_data.constants import DOWNLOAD_CHECKS, NORMAL_REPO, UPDATE_REPO
# Local model directories, all rooted next to this file.
BASE_PATH = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_PATH, 'models')
VR_MODELS_DIR = os.path.join(MODELS_DIR, 'VR_Models')
MDX_MODELS_DIR = os.path.join(MODELS_DIR, 'MDX_Net_Models')
DEMUCS_MODELS_DIR = os.path.join(MODELS_DIR, 'Demucs_Models')
DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo')

# Fetch the online model catalogue once, at import time.  The response is used
# as a context manager so the HTTP connection is closed once the JSON is parsed
# (previously it was left open).
with urllib.request.urlopen(DOWNLOAD_CHECKS) as _catalogue_response:
    online_model_data = json.load(_catalogue_response)

# MDX-Net, MDX23C and MDX23 entries share one lookup table (and one target
# directory, MDX_MODELS_DIR).  On duplicate names, later lists override earlier
# ones.  NOTE(review): the VIP lists below are commented out, presumably on
# purpose — confirm before re-enabling them.
mdx_download_list = {
    **online_model_data["mdx_download_list"],
    **online_model_data["mdx23c_download_list"],
    **online_model_data["mdx23_download_list"],
    # **online_model_data["mdx_download_vip_list"],
    # **online_model_data["mdx23c_download_vip_list"],
}
vr_download_list = online_model_data["vr_download_list"]
demucs_download_list = online_model_data["demucs_download_list"]
def get_mdx_model_file(model):
    """Return the local path of the file for MDX catalogue entry *model*.

    Raises:
        KeyError: if *model* is not in ``mdx_download_list``.
    """
    entries = get_mdx_model_filelist(model)
    local_path, _remote_url = entries[0]
    return local_path
def get_mdx_model_filelist(model):
    """Return ``[(local_path, url)]`` for the MDX-family catalogue entry *model*.

    A catalogue value is either a plain filename or a dict.  For a dict, only
    its first key is used as the filename; its values are not used here.
    The file lives in ``MDX_MODELS_DIR`` and is fetched from ``NORMAL_REPO``.

    Raises:
        KeyError: if *model* is not in ``mdx_download_list``.
        IndexError: if the catalogue value is an empty dict.
    """
    entry = mdx_download_list[model]
    # Dict entries carry the model filename as their first key.
    file_name = list(entry)[0] if isinstance(entry, dict) else str(entry)
    local_path = os.path.join(MDX_MODELS_DIR, file_name)
    return [(local_path, f"{NORMAL_REPO}{file_name}")]
def get_vr_model_file(model):
    """Return the local path of the file for VR catalogue entry *model*.

    Raises:
        KeyError: if *model* is not in ``vr_download_list``.
    """
    entries = get_vr_model_filelist(model)
    local_path, _remote_url = entries[0]
    return local_path
def get_vr_model_filelist(model):
    """Return ``[(local_path, url)]`` for VR catalogue entry *model*.

    The catalogue value is used directly as the filename: the file lives in
    ``VR_MODELS_DIR`` and is fetched from ``NORMAL_REPO``.

    Raises:
        KeyError: if *model* is not in ``vr_download_list``.
    """
    vr_filename = vr_download_list[model]
    local_path = os.path.join(VR_MODELS_DIR, vr_filename)
    remote_url = f"{NORMAL_REPO}{vr_filename}"
    return [(local_path, remote_url)]
def get_demucs_model_file(model):
    """Return the local ``.yaml`` path listed for Demucs catalogue entry *model*.

    The suffix check is case-insensitive.  If several ``.yaml`` files are
    listed, the first one in catalogue order is returned.  If none is listed,
    ``None`` is returned.

    Raises:
        KeyError: if *model* is not in ``demucs_download_list``.
    """
    yaml_paths = (
        local_path
        for local_path, _remote_url in get_demucs_model_filelist(model)
        if local_path.lower().endswith('.yaml')
    )
    return next(yaml_paths, None)
def get_demucs_model_filelist(model):
    """Return ``(local_path, url)`` pairs for every file of Demucs entry *model*.

    The catalogue value maps each filename to its download URL.  Every file is
    placed directly in ``DEMUCS_NEWER_REPO_DIR``.  Pairs come back in the
    catalogue's dict-iteration order.

    Raises:
        KeyError: if *model* is not in ``demucs_download_list``.
    """
    return [
        (os.path.join(DEMUCS_NEWER_REPO_DIR, filename), url)
        for filename, url in demucs_download_list[model].items()
    ]
def get_model_file(model_name):
    """Return the local path of the main file for *model_name*.

    The MDX, VR and Demucs catalogues are searched in that order and the first
    one containing *model_name* decides how the path is resolved.  For Demucs
    entries this is the ``.yaml`` file, or ``None`` if none is listed.

    Raises:
        FileNotFoundError: if *model_name* is in none of the catalogues.
    """
    resolvers = (
        (mdx_download_list, get_mdx_model_file),
        (vr_download_list, get_vr_model_file),
        (demucs_download_list, get_demucs_model_file),
    )
    for catalogue, resolve in resolvers:
        if model_name in catalogue:
            return resolve(model_name)
    raise FileNotFoundError(f"Can't found model {model_name}")
def download_model(model_name):
    """Download every file of *model_name* that is not already on disk.

    The file list comes from whichever catalogue (MDX, VR, Demucs, checked in
    that order) contains *model_name*.  Each file that does not already exist
    locally is fetched with ``wget.download``.  Existing files are skipped
    individually, so a partially present multi-file model is completed.

    Raises:
        FileNotFoundError: if *model_name* is in none of the catalogues.
    """
    if model_name in mdx_download_list:
        filelist = get_mdx_model_filelist(model_name)
    elif model_name in vr_download_list:
        filelist = get_vr_model_filelist(model_name)
    elif model_name in demucs_download_list:
        filelist = get_demucs_model_filelist(model_name)
    else:
        raise FileNotFoundError(f"Can't found model {model_name}")
    for model_path, url in filelist:
        # Skip only this file: returning here would leave the remaining files
        # of a multi-file (e.g. Demucs) model undownloaded.
        if os.path.isfile(model_path):
            continue
        # The models sub-directory may not exist yet on a fresh checkout.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        print(f'Downloading from {url} to {model_path}')
        wget.download(url, model_path)
if __name__ == '__main__':
    # CLI: python download.py MODEL_NAME — downloads one catalogue entry.
    # Exit with a usage message instead of an IndexError when no name is given.
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL_NAME")
    model_name = sys.argv[1]
    download_model(model_name)