HoneyTian commited on
Commit
589e655
1 Parent(s): 8f39105
Files changed (2) hide show
  1. examples/wenet/downaload_model.py +19 -8
  2. main.py +1 -0
examples/wenet/downaload_model.py CHANGED
@@ -15,15 +15,26 @@ from project_settings import project_path
15
 
16
  def get_args():
17
  parser = argparse.ArgumentParser()
 
 
 
 
 
 
 
 
 
 
 
18
  parser.add_argument(
19
  "--repo_id",
20
- default="csukuangfj/wenet-chinese-model",
21
- # default="csukuangfj/wenet-english-model",
22
  type=str
23
  )
24
-
25
- parser.add_argument("--model_filename", default="final.zip", type=str)
26
- parser.add_argument("--tokens_filename", default="units.txt", type=str)
 
27
 
28
  parser.add_argument(
29
  "--pretrained_model_dir",
@@ -41,13 +52,13 @@ def main():
41
  pretrained_model_dir.mkdir(exist_ok=True)
42
 
43
  model_dir = pretrained_model_dir / "huggingface" / args.repo_id
44
- model_dir.mkdir(exist_ok=True)
45
 
46
  print("download model")
47
  model_filename = huggingface_hub.hf_hub_download(
48
  repo_id=args.repo_id,
49
  filename=args.model_filename,
50
- subfolder=".",
51
  local_dir=model_dir.as_posix(),
52
  )
53
  print(model_filename)
@@ -56,7 +67,7 @@ def main():
56
  tokens_filename = huggingface_hub.hf_hub_download(
57
  repo_id=args.repo_id,
58
  filename=args.tokens_filename,
59
- subfolder=".",
60
  local_dir=model_dir.as_posix(),
61
  )
62
  print(tokens_filename)
 
15
 
16
  def get_args():
17
  parser = argparse.ArgumentParser()
18
+ # parser.add_argument(
19
+ # "--repo_id",
20
+ # default="csukuangfj/wenet-chinese-model",
21
+ # # default="csukuangfj/wenet-english-model",
22
+ # type=str
23
+ # )
24
+ # parser.add_argument("--model_filename", default="final.zip", type=str)
25
+ # parser.add_argument("--model_sub_folder", default=".", type=str)
26
+ # parser.add_argument("--tokens_filename", default="units.txt", type=str)
27
+ # parser.add_argument("--tokens_sub_folder", default=".", type=str)
28
+
29
  parser.add_argument(
30
  "--repo_id",
31
+ default="luomingshuang/icefall_asr_wenetspeech_pruned_transducer_stateless2",
 
32
  type=str
33
  )
34
+ parser.add_argument("--model_filename", default="cpu_jit_epoch_10_avg_2_torch_1.7.1.pt", type=str)
35
+ parser.add_argument("--model_sub_folder", default="exp", type=str)
36
+ parser.add_argument("--tokens_filename", default="tokens.txt", type=str)
37
+ parser.add_argument("--tokens_sub_folder", default="data/lang_char", type=str)
38
 
39
  parser.add_argument(
40
  "--pretrained_model_dir",
 
52
  pretrained_model_dir.mkdir(exist_ok=True)
53
 
54
  model_dir = pretrained_model_dir / "huggingface" / args.repo_id
55
+ model_dir.mkdir(parents=True, exist_ok=True)
56
 
57
  print("download model")
58
  model_filename = huggingface_hub.hf_hub_download(
59
  repo_id=args.repo_id,
60
  filename=args.model_filename,
61
+ subfolder=args.model_sub_folder,
62
  local_dir=model_dir.as_posix(),
63
  )
64
  print(model_filename)
 
67
  tokens_filename = huggingface_hub.hf_hub_download(
68
  repo_id=args.repo_id,
69
  filename=args.tokens_filename,
70
+ subfolder=args.tokens_sub_folder,
71
  local_dir=model_dir.as_posix(),
72
  )
73
  print(tokens_filename)
main.py CHANGED
@@ -101,6 +101,7 @@ def process(
101
 
102
  # load recognizer
103
  local_model_dir = pretrained_model_dir / "huggingface" / repo_id
 
104
 
105
  recognizer = nn_models.load_recognizer(
106
  local_model_dir=local_model_dir,
 
101
 
102
  # load recognizer
103
  local_model_dir = pretrained_model_dir / "huggingface" / repo_id
104
+ local_model_dir.mkdir(parents=True, exist_ok=True)
105
 
106
  recognizer = nn_models.load_recognizer(
107
  local_model_dir=local_model_dir,