File size: 3,266 Bytes
83c8e0b
 
 
 
 
 
 
 
 
 
 
14a97f2
 
 
 
 
 
 
83c8e0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import json
import os
import sys
import urllib
import urllib.request
from pprint import pprint

import wget
from tqdm import tqdm

from gui_data.constants import DOWNLOAD_CHECKS, NORMAL_REPO, UPDATE_REPO

# Filesystem layout: every model folder lives under "<directory of this file>/models".
BASE_PATH = os.path.split(os.path.abspath(__file__))[0]
MODELS_DIR = os.path.join(BASE_PATH, 'models')

# One sub-folder per model family.
VR_MODELS_DIR = os.path.join(BASE_PATH, 'models', 'VR_Models')
MDX_MODELS_DIR = os.path.join(BASE_PATH, 'models', 'MDX_Net_Models')
DEMUCS_MODELS_DIR = os.path.join(BASE_PATH, 'models', 'Demucs_Models')

# Nested folder inside the Demucs directory.
DEMUCS_NEWER_REPO_DIR = os.path.join(DEMUCS_MODELS_DIR, 'v3_v4_repo')

# Fetch the online model catalogue once, at import time. The context manager
# closes the HTTP response as soon as the JSON payload has been parsed instead
# of leaving the connection open until garbage collection.
with urllib.request.urlopen(DOWNLOAD_CHECKS) as _catalogue_response:
    online_model_data = json.load(_catalogue_response)

# All MDX-style catalogues merged into one lookup table; on duplicate keys the
# later catalogue wins. The VIP lists below are currently not merged in.
mdx_download_list = {
    **online_model_data["mdx_download_list"],
    **online_model_data["mdx23c_download_list"],
    **online_model_data["mdx23_download_list"],
    # **online_model_data["mdx_download_vip_list"],
    # **online_model_data["mdx23c_download_vip_list"],
}
vr_download_list = online_model_data["vr_download_list"]
demucs_download_list = online_model_data["demucs_download_list"]


def get_mdx_model_file(model):
    """Return the local path of the MDX model registered as ``model``.

    The path is the first element of the first ``(path, url)`` pair produced
    by ``get_mdx_model_filelist``.
    """
    first_entry = get_mdx_model_filelist(model)[0]
    return first_entry[0]


def get_mdx_model_filelist(model):
    """Build the download list for the MDX model registered as ``model``.

    The entry in ``mdx_download_list`` is either a dict, in which case its
    first key is taken as the file name, or any other value, which is
    converted to the file name with ``str()``.

    Returns:
        A one-element list holding a ``(local_path, url)`` tuple, where the
        path is inside ``MDX_MODELS_DIR`` and the URL is ``NORMAL_REPO``
        followed by the file name.

    Raises:
        KeyError: if ``model`` is not a key of ``mdx_download_list``.
    """
    entry = mdx_download_list[model]
    model_name = list(entry)[0] if isinstance(entry, dict) else str(entry)
    return [(os.path.join(MDX_MODELS_DIR, model_name), f"{NORMAL_REPO}{model_name}")]


def get_vr_model_file(model):
    """Return the local path of the VR model registered as ``model``.

    The path is the first element of the first ``(path, url)`` pair produced
    by ``get_vr_model_filelist``.
    """
    first_entry = get_vr_model_filelist(model)[0]
    return first_entry[0]


def get_vr_model_filelist(model):
    """Build the download list for the VR model registered as ``model``.

    Returns:
        A one-element list holding a ``(local_path, url)`` tuple, where the
        path is inside ``VR_MODELS_DIR`` and the URL is ``NORMAL_REPO``
        followed by the file name from ``vr_download_list``.

    Raises:
        KeyError: if ``model`` is not a key of ``vr_download_list``.
    """
    filename = vr_download_list[model]
    local_path = os.path.join(VR_MODELS_DIR, filename)
    return [(local_path, f"{NORMAL_REPO}{filename}")]


def get_demucs_model_file(model):
    """Return the local path of the Demucs model's ``.yaml`` file.

    Scans the ``(path, url)`` pairs from ``get_demucs_model_filelist`` in
    order and returns the first path whose lower-cased form ends with
    ``.yaml``. Returns ``None`` when no such path exists.
    """
    yaml_paths = (
        path
        for path, _url in get_demucs_model_filelist(model)
        if path.lower().endswith('.yaml')
    )
    return next(yaml_paths, None)


def get_demucs_model_filelist(model):
    """Build the download list for the Demucs model registered as ``model``.

    The entry in ``demucs_download_list`` maps file names to URLs.

    Returns:
        A list of ``(local_path, url)`` tuples, one per file, with every path
        placed inside ``DEMUCS_NEWER_REPO_DIR``.

    Raises:
        KeyError: if ``model`` is not a key of ``demucs_download_list``.
    """
    files = demucs_download_list[model]
    return [
        (os.path.join(DEMUCS_NEWER_REPO_DIR, filename), url)
        for filename, url in files.items()
    ]


def get_model_file(model_name):
    """Return the local file path for ``model_name``.

    The catalogues are consulted in a fixed order: MDX, then VR, then Demucs.
    The first one that contains ``model_name`` decides the result.

    Raises:
        FileNotFoundError: if ``model_name`` is in none of the catalogues.
    """
    if model_name in mdx_download_list:
        return get_mdx_model_file(model_name)
    if model_name in vr_download_list:
        return get_vr_model_file(model_name)
    if model_name in demucs_download_list:
        return get_demucs_model_file(model_name)
    raise FileNotFoundError(f"Can't found model {model_name}")


def download_model(model_name):
    """Download every file of ``model_name`` that is not on disk yet.

    The catalogues are consulted in a fixed order: MDX, then VR, then Demucs.
    Each ``(local_path, url)`` pair of the matching model is checked
    individually, so a model whose files were only partially downloaded gets
    its missing files fetched on the next call.

    Args:
        model_name: Key of the model in one of the download catalogues.

    Raises:
        FileNotFoundError: if ``model_name`` is in none of the catalogues.
    """
    if model_name in mdx_download_list:
        filelist = get_mdx_model_filelist(model_name)
    elif model_name in vr_download_list:
        filelist = get_vr_model_filelist(model_name)
    elif model_name in demucs_download_list:
        filelist = get_demucs_model_filelist(model_name)
    else:
        raise FileNotFoundError(f"Can't found model {model_name}")

    for model_path, url in filelist:
        if os.path.isfile(model_path):
            # Skip only this file; returning here would leave the remaining
            # files of a multi-file model undownloaded.
            continue
        # Make sure the destination folder exists before writing into it.
        target_dir = os.path.dirname(model_path)
        if target_dir:
            os.makedirs(target_dir, exist_ok=True)
        print(f'Downloading from {url} to {model_path}')
        wget.download(url, model_path)


if __name__ == '__main__':
    # The model name is expected as the first command-line argument. Exit with
    # a usage message (status 1) instead of an IndexError traceback when it is
    # missing.
    if len(sys.argv) < 2:
        sys.exit(f"Usage: {sys.argv[0]} MODEL_NAME")
    model_name = sys.argv[1]
    download_model(model_name)