File size: 2,844 Bytes
5ce1fe8
 
 
941850e
5ce1fe8
941850e
 
 
 
5ce1fe8
 
 
 
 
 
 
 
589e655
 
 
 
 
 
 
 
 
 
 
5ce1fe8
 
589e655
5ce1fe8
 
589e655
 
 
 
5ce1fe8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3147eb6
 
 
d8e6dc5
3147eb6
 
 
 
d8e6dc5
3147eb6
 
 
 
 
 
5ce1fe8
 
 
 
 
589e655
3147eb6
5ce1fe8
 
234de07
5ce1fe8
 
159e07d
5ce1fe8
 
589e655
e924ab6
5ce1fe8
159e07d
5ce1fe8
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
import os
from pathlib import Path
import sys

# Absolute path of the directory containing this script.
pwd = os.path.abspath(os.path.dirname(__file__))
# Put the directory two levels above this script on the import path so the
# project-local `project_settings` module imported below can be resolved.
sys.path.append(os.path.join(pwd, "../../"))

import huggingface_hub

from project_settings import project_path


def get_args():
    """Parse the command-line options for downloading a Hugging Face Hub model.

    Every option is a plain string. The defaults target the icefall
    WenetSpeech pruned-transducer checkpoint. A previously used alternative
    configuration was::

        --repo_id csukuangfj/wenet-chinese-model   (or csukuangfj/wenet-english-model)
        --model_filename final.zip --model_sub_folder .
        --tokens_filename units.txt --tokens_sub_folder .

    :return: the ``argparse.Namespace`` produced by ``parse_args()``.
    """
    default_download_root = (project_path / "pretrained_models").as_posix()

    # (flag, default) pairs, registered in this exact order so ``--help``
    # lists them the same way.
    option_table = (
        ("--repo_id", "luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2"),
        ("--model_filename", "cpu_jit_epoch_10_avg_2_torch_1.7.1.pt"),
        ("--model_sub_folder", "exp"),
        ("--tokens_filename", "tokens.txt"),
        ("--tokens_sub_folder", "data/lang_char"),
        ("--pretrained_model_dir", default_download_root),
    )

    cli = argparse.ArgumentParser()
    for flag, fallback in option_table:
        cli.add_argument(flag, default=fallback, type=str)

    return cli.parse_args()


def _repo_folder(repo_id: str, max_name_len: int = 30) -> str:
    """Map a Hub repo id to the relative local folder used to store its files.

    ``"name"`` maps to ``"name"`` and ``"owner/name"`` maps to ``"owner/name"``.
    In both cases the repository name is truncated to its first
    ``max_name_len`` characters (presumably to keep local paths short).

    :raises AssertionError: if ``repo_id`` does not have one or two path parts.
    """
    parts = Path(repo_id).parts
    if len(parts) == 1:
        return parts[0][:max_name_len]
    if len(parts) == 2:
        supplier, name = parts
        return "{}/{}".format(supplier, name[:max_name_len])
    raise AssertionError("repo_id parts count invalid: {}".format(len(parts)))


def _hub_subfolder(sub_folder: str):
    """Return ``None`` when ``sub_folder`` denotes the repo root ("" or "."), else ``sub_folder``.

    ``hf_hub_download`` prefixes the filename with the subfolder. A literal
    "." would therefore yield a "./<file>" path instead of the root file.
    """
    return None if sub_folder in ("", ".") else sub_folder


def main():
    """Download the model file and the tokens file of ``--repo_id``.

    Both files are stored under
    ``<pretrained_model_dir>/huggingface/<folder>``, where ``<folder>`` is
    derived from the repo id by ``_repo_folder``. The local path of each
    downloaded file is printed.
    """
    args = get_args()

    local_model_dir = Path(args.pretrained_model_dir) / "huggingface" / _repo_folder(args.repo_id)
    # parents=True also creates ``pretrained_model_dir`` itself if it (or any
    # of its parents) does not exist yet.
    local_model_dir.mkdir(parents=True, exist_ok=True)

    print("download model")
    model_filename = huggingface_hub.hf_hub_download(
        repo_id=args.repo_id,
        filename=args.model_filename,
        subfolder=_hub_subfolder(args.model_sub_folder),
        local_dir=local_model_dir.as_posix(),
    )
    print(model_filename)

    print("download tokens")
    tokens_filename = huggingface_hub.hf_hub_download(
        repo_id=args.repo_id,
        filename=args.tokens_filename,
        subfolder=_hub_subfolder(args.tokens_sub_folder),
        local_dir=local_model_dir.as_posix(),
    )
    print(tokens_filename)


if __name__ == "__main__":
    # Run the download only when executed as a script, not when imported.
    main()