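"""Initialization script: downloads the model assets required by Style-Bert-VITS2
(BERT models, the WavLM model, pretrained base checkpoints, and default voice
models) from the Hugging Face Hub, and sets up configs/paths.yml.

Example invocation (assuming this file is saved as initialize.py):
    python initialize.py --only_infer --dataset_root Data --assets_root model_assets
"""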
import argparse
import json
import shutil
from pathlib import Path

import yaml
from huggingface_hub import hf_hub_download

from style_bert_vits2.logging import logger


def download_bert_models():
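    """Download the BERT models listed in bert/bert_models.json into bert/,
    skipping any files that are already present.
    """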
    with open("bert/bert_models.json", encoding="utf-8") as fp:
        models = json.load(fp)
    for k, v in models.items():
        local_path = Path("bert") / k
        for file in v["files"]:
            if not (local_path / file).exists():
                logger.info(f"Downloading {k} {file}")
                hf_hub_download(v["repo_id"], file, local_dir=local_path)


def download_slm_model():
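    """Download the microsoft/wavlm-base-plus weights into slm/wavlm-base-plus/."""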
    local_path = Path("slm/wavlm-base-plus/")
    file = "pytorch_model.bin"
    if not (local_path / file).exists():
        logger.info(f"Downloading wavlm-base-plus {file}")
        hf_hub_download("microsoft/wavlm-base-plus", file, local_dir=local_path)


def download_pretrained_models():
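    """Download the base pretrained checkpoints (G_0/D_0/DUR_0) into pretrained/;
    only needed for training (skipped when --only_infer is set).
    """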
    files = ["G_0.safetensors", "D_0.safetensors", "DUR_0.safetensors"]
    local_path = Path("pretrained")
    for file in files:
        if not (local_path / file).exists():
            logger.info(f"Downloading pretrained {file}")
            hf_hub_download(
                "litagin/Style-Bert-VITS2-1.0-base", file, local_dir=local_path
            )


def download_jp_extra_pretrained_models():
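    """Download the JP-Extra pretrained checkpoints (G_0/D_0/WD_0) into pretrained_jp_extra/."""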
    files = ["G_0.safetensors", "D_0.safetensors", "WD_0.safetensors"]
    local_path = Path("pretrained_jp_extra")
    for file in files:
        if not (local_path / file).exists():
            logger.info(f"Downloading JP-Extra pretrained {file}")
            hf_hub_download(
                "litagin/Style-Bert-VITS2-2.0-base-JP-Extra", file, local_dir=local_path
            )


def download_default_models():
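    """Download the default JVNV voice models, plus koharune-ami and amitaro, into model_assets/."""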
    files = [
        "jvnv-F1-jp/config.json",
        "jvnv-F1-jp/jvnv-F1-jp_e160_s14000.safetensors",
        "jvnv-F1-jp/style_vectors.npy",
        "jvnv-F2-jp/config.json",
        "jvnv-F2-jp/jvnv-F2_e166_s20000.safetensors",
        "jvnv-F2-jp/style_vectors.npy",
        "jvnv-M1-jp/config.json",
        "jvnv-M1-jp/jvnv-M1-jp_e158_s14000.safetensors",
        "jvnv-M1-jp/style_vectors.npy",
        "jvnv-M2-jp/config.json",
        "jvnv-M2-jp/jvnv-M2-jp_e159_s17000.safetensors",
        "jvnv-M2-jp/style_vectors.npy",
    ]
    for file in files:
        if not Path(f"model_assets/{file}").exists():
            logger.info(f"Downloading {file}")
            hf_hub_download(
                "litagin/style_bert_vits2_jvnv",
                file,
                local_dir="model_assets",
            )

    additional_files = {
        "litagin/sbv2_koharune_ami": [
            "koharune-ami/config.json",
            "koharune-ami/style_vectors.npy",
            "koharune-ami/koharune-ami.safetensors",
        ],
        "litagin/sbv2_amitaro": [
            "amitaro/config.json",
            "amitaro/style_vectors.npy",
            "amitaro/amitaro.safetensors",
        ],
    }
    for repo_id, files in additional_files.items():
        for file in files:
            if not Path(f"model_assets/{file}").exists():
                logger.info(f"Downloading {file}")
                hf_hub_download(
                    repo_id,
                    file,
                    local_dir="model_assets",
                )


def main():
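    """Parse CLI options, download the requested models, and set up configs/paths.yml."""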
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip_default_models", action="store_true")
    parser.add_argument("--only_infer", action="store_true")
    parser.add_argument(
        "--dataset_root",
        type=str,
        help="Dataset root path (default: Data)",
        default=None,
    )
    parser.add_argument(
        "--assets_root",
        type=str,
        help="Assets root path (default: model_assets)",
        default=None,
    )
    args = parser.parse_args()

    download_bert_models()

    if not args.skip_default_models:
        download_default_models()
    if not args.only_infer:
        download_slm_model()
        download_pretrained_models()
        download_jp_extra_pretrained_models()

    # Create configs/paths.yml from the default template on first run
    default_paths_yml = Path("configs/default_paths.yml")
    paths_yml = Path("configs/paths.yml")
    if not paths_yml.exists():
        shutil.copy(default_paths_yml, paths_yml)

    if args.dataset_root is None and args.assets_root is None:
        return

    # Override the root paths in paths.yml with any command-line values
    with open(paths_yml, encoding="utf-8") as f:
        yml_data = yaml.safe_load(f)
    if args.assets_root is not None:
        yml_data["assets_root"] = args.assets_root
    if args.dataset_root is not None:
        yml_data["dataset_root"] = args.dataset_root
    with open(paths_yml, "w", encoding="utf-8") as f:
        yaml.dump(yml_data, f, allow_unicode=True)


if __name__ == "__main__":
    main()