"""Look up and download model files listed in an online model index.

Usage: python <this script> MODEL_NAME

At import time the JSON index at ``DOWNLOAD_CHECKS`` is fetched and split
into three lookup tables (MDX, VR, Demucs) keyed by model name. Downloaded
files are stored under ``models/`` next to this file.
"""
import json
import os
import sys
import urllib.request  # also binds `urllib`; a bare `import urllib` does not load `urllib.request`
from pprint import pprint

import wget
from tqdm import tqdm

from gui_data.constants import DOWNLOAD_CHECKS, NORMAL_REPO, UPDATE_REPO

BASE_PATH = os.path.dirname(os.path.abspath(__file__))
MODELS_DIR = os.path.join(BASE_PATH, 'models')
VR_MODELS_DIR = os.path.join(MODELS_DIR, 'VR_Models')
MDX_MODELS_DIR = os.path.join(MODELS_DIR, 'MDX_Net_Models')
DEMUCS_MODELS_DIR = os.path.join(MODELS_DIR, 'Demucs_Models')
DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo')


def _fetch_json(url):
    """Download *url* and parse the response body as JSON."""
    # `with` closes the HTTP response once it has been parsed.
    with urllib.request.urlopen(url) as response:
        return json.load(response)


# Fetched once, at import time; every lookup below reads from these tables.
online_model_data = _fetch_json(DOWNLOAD_CHECKS)

# The MDX, MDX23C and MDX23 lists are merged into one table; on a duplicate
# model name the entry from the later list wins. The VIP lists are currently
# commented out.
mdx_download_list = {
    **online_model_data["mdx_download_list"],
    **online_model_data["mdx23c_download_list"],
    **online_model_data["mdx23_download_list"],
    # **online_model_data["mdx_download_vip_list"],
    # **online_model_data["mdx23c_download_vip_list"],
}
vr_download_list = online_model_data["vr_download_list"]
demucs_download_list = online_model_data["demucs_download_list"]


def get_mdx_model_file(model):
    """Return the local path of the MDX model file for *model*."""
    return get_mdx_model_filelist(model)[0][0]


def get_mdx_model_filelist(model):
    """Return ``[(local_path, url)]`` for the MDX model *model*.

    An index entry is either a filename or a dict; for a dict, only its first
    key is used as the filename. The URL is ``NORMAL_REPO`` + filename.

    Raises:
        KeyError: if *model* is not in ``mdx_download_list``.
    """
    entry = mdx_download_list[model]
    if isinstance(entry, dict):
        # NOTE(review): the rest of the dict is ignored here; presumably it
        # carries data consumed elsewhere — confirm.
        model_name = list(entry.keys())[0]
    else:
        model_name = str(entry)
    model_path = os.path.join(MDX_MODELS_DIR, model_name)
    url = f"{NORMAL_REPO}{model_name}"
    return [(model_path, url)]


def get_vr_model_file(model):
    """Return the local path of the VR model file for *model*."""
    return get_vr_model_filelist(model)[0][0]


def get_vr_model_filelist(model):
    """Return ``[(local_path, url)]`` for the VR model *model*.

    Raises:
        KeyError: if *model* is not in ``vr_download_list``.
    """
    filename = vr_download_list[model]
    url = f"{NORMAL_REPO}{filename}"
    model_path = os.path.join(VR_MODELS_DIR, filename)
    return [(model_path, url)]


def get_demucs_model_file(model):
    """Return the local path of the ``.yaml`` file among *model*'s Demucs files.

    Returns ``None`` when none of the model's files ends in ``.yaml``
    (case-insensitive).
    """
    for model_path, _url in get_demucs_model_filelist(model):
        if model_path.lower().endswith('.yaml'):
            return model_path
    return None


def get_demucs_model_filelist(model):
    """Return ``[(local_path, url), ...]`` for every file of Demucs model *model*.

    The index maps each filename to its own download URL; all files are placed
    in ``DEMUCS_NEWER_REPO_DIR``.

    Raises:
        KeyError: if *model* is not in ``demucs_download_list``.
    """
    return [
        (os.path.join(DEMUCS_NEWER_REPO_DIR, filename), url)
        for filename, url in demucs_download_list[model].items()
    ]


def get_model_file(model_name):
    """Return the local path of the primary file for *model_name*.

    For Demucs models this is the ``.yaml`` file, or ``None`` if there is none.

    Raises:
        FileNotFoundError: if *model_name* is in none of the download lists.
    """
    if model_name in mdx_download_list:
        model_path = get_mdx_model_file(model_name)
    elif model_name in vr_download_list:
        model_path = get_vr_model_file(model_name)
    elif model_name in demucs_download_list:
        model_path = get_demucs_model_file(model_name)
    else:
        raise FileNotFoundError(f"Can't found model {model_name}")
    return model_path


def download_model(model_name):
    """Download every file of *model_name* that is not already on disk.

    Raises:
        FileNotFoundError: if *model_name* is in none of the download lists.
    """
    if model_name in mdx_download_list:
        filelist = get_mdx_model_filelist(model_name)
    elif model_name in vr_download_list:
        filelist = get_vr_model_filelist(model_name)
    elif model_name in demucs_download_list:
        filelist = get_demucs_model_filelist(model_name)
    else:
        raise FileNotFoundError(f"Can't found model {model_name}")

    for model_path, url in filelist:
        # Skip only this file: a model may consist of several files (Demucs),
        # and the others may still be missing.
        if os.path.isfile(model_path):
            continue
        # Create the destination directory if it does not exist yet.
        os.makedirs(os.path.dirname(model_path), exist_ok=True)
        print(f'Downloading from {url} to {model_path}')
        wget.download(url, model_path)


if __name__ == '__main__':
    if len(sys.argv) < 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL_NAME")
    model_name = sys.argv[1]
    download_model(model_name)