|
|
|
|
|
import argparse |
|
import os |
|
from pathlib import Path |
|
import sys |
|
|
|
# Absolute path of the directory containing this script.
pwd = os.path.abspath(os.path.dirname(__file__))
# Append the directory two levels up to the import path so that the
# `project_settings` import below resolves when this file is run directly.
sys.path.append(os.path.join(pwd, "../../"))
|
|
|
import huggingface_hub |
|
|
|
from project_settings import project_path |
|
|
|
|
|
def get_args():
    """Parse command-line options for downloading a Hugging Face model.

    Every option is a string flag with a default:

    * ``--repo_id``: Hugging Face repository to download from.
    * ``--model_filename`` / ``--model_sub_folder``: model file and the
      repository subfolder that holds it.
    * ``--tokens_filename`` / ``--tokens_sub_folder``: tokens file and the
      repository subfolder that holds it.
    * ``--pretrained_model_dir``: local root directory for downloads
      (defaults to ``<project_path>/pretrained_models``).

    Returns:
        The ``argparse.Namespace`` produced by ``parse_args()``.
    """
    # (flag, default) pairs, registered in this exact order.
    string_options = (
        ("--repo_id", "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2"),
        ("--model_filename", "cpu_jit_epoch_10_avg_2_torch_1.7.1.pt"),
        ("--model_sub_folder", "exp"),
        ("--tokens_filename", "tokens.txt"),
        ("--tokens_sub_folder", "data/lang_char"),
        ("--pretrained_model_dir", (project_path / "pretrained_models").as_posix()),
    )

    arg_parser = argparse.ArgumentParser()
    for flag, default_value in string_options:
        arg_parser.add_argument(flag, default=default_value, type=str)

    return arg_parser.parse_args()
|
|
|
|
|
def _resolve_repo_folder(repo_id, max_name_len=30):
    """Map a Hugging Face ``repo_id`` to a relative local folder path.

    ``"name"`` maps to ``"name"`` and ``"owner/name"`` maps to
    ``"owner/name"``. Only the repository name (the last component) is
    truncated to ``max_name_len`` characters; the owner is kept as-is.

    Raises:
        AssertionError: if ``repo_id`` does not split into one or two parts.
    """
    parts = Path(repo_id).parts
    if len(parts) == 1:
        return parts[0][:max_name_len]
    if len(parts) == 2:
        repo_supplier, repo_name = parts
        return "{}/{}".format(repo_supplier, repo_name[:max_name_len])
    raise AssertionError("repo_id parts count invalid: {}".format(len(parts)))


def main():
    """Download the model file and the tokens file for ``--repo_id``.

    Both files are stored under
    ``<pretrained_model_dir>/huggingface/<owner>/<name>`` (see
    ``_resolve_repo_folder``), and the path returned by each
    ``hf_hub_download`` call is printed.
    """
    args = get_args()

    pretrained_model_dir = Path(args.pretrained_model_dir)
    # parents=True so a nested, not-yet-existing --pretrained_model_dir works
    # instead of raising FileNotFoundError.
    pretrained_model_dir.mkdir(parents=True, exist_ok=True)

    folder = _resolve_repo_folder(args.repo_id)
    local_model_dir = pretrained_model_dir / "huggingface" / folder
    local_model_dir.mkdir(parents=True, exist_ok=True)

    print("download model")
    model_filename = huggingface_hub.hf_hub_download(
        repo_id=args.repo_id,
        filename=args.model_filename,
        subfolder=args.model_sub_folder,
        local_dir=local_model_dir.as_posix(),
    )
    print(model_filename)

    # A leftover `exit(0)` here previously made the tokens download
    # unreachable; both files are now fetched.
    print("download tokens")
    tokens_filename = huggingface_hub.hf_hub_download(
        repo_id=args.repo_id,
        filename=args.tokens_filename,
        subfolder=args.tokens_sub_folder,
        local_dir=local_model_dir.as_posix(),
    )
    print(tokens_filename)
|
|
|
|
|
# Run the downloader only when this file is executed as a script.
if __name__ == "__main__":

    main()
|
|