HoneyTian committed on
Commit
d03c698
1 Parent(s): d9b0161
Files changed (3) hide show
  1. Dockerfile +1 -1
  2. main.py +7 -3
  3. toolbox/k2_sherpa/nn_models.py +37 -11
Dockerfile CHANGED
@@ -9,7 +9,7 @@ COPY . /code/
9
 
10
  RUN pip install --upgrade pip
11
 
12
- RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple
13
 
14
  # libk2_torch_api.so
15
  RUN export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.8/site-packages/k2/lib/
 
9
 
10
  RUN pip install --upgrade pip
11
 
12
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
13
 
14
  # libk2_torch_api.so
15
  RUN export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.8/site-packages/k2/lib/
main.py CHANGED
@@ -4,13 +4,17 @@ import argparse
4
  from collections import defaultdict
5
  from datetime import datetime
6
  import functools
7
- import io
8
  import logging
 
9
  from pathlib import Path
10
  import platform
11
  import time
12
  import tempfile
13
 
 
 
 
 
14
  from project_settings import project_path, log_directory
15
  import log
16
 
@@ -109,8 +113,8 @@ def process(
109
  nn_model_file=nn_model_file.as_posix(),
110
  tokens_file=tokens_file.as_posix(),
111
  sub_folder=m_dict["sub_folder"],
112
- local_model_dir=local_model_dir,
113
- recognizer_type=m_dict["recognizer_type"],
114
  decoding_method=decoding_method,
115
  num_active_paths=num_active_paths,
116
  )
 
4
  from collections import defaultdict
5
  from datetime import datetime
6
  import functools
 
7
  import logging
8
+ import os
9
  from pathlib import Path
10
  import platform
11
  import time
12
  import tempfile
13
 
14
+ os.system(
15
+ "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib/python3.8/site-packages/k2/lib/"
16
+ )
17
+
18
  from project_settings import project_path, log_directory
19
  import log
20
 
 
113
  nn_model_file=nn_model_file.as_posix(),
114
  tokens_file=tokens_file.as_posix(),
115
  sub_folder=m_dict["sub_folder"],
116
+ local_model_dir=local_model_dir.as_posix(),
117
+ loader=m_dict["loader"],
118
  decoding_method=decoding_method,
119
  num_active_paths=num_active_paths,
120
  )
toolbox/k2_sherpa/nn_models.py CHANGED
@@ -6,6 +6,7 @@ import os
6
 
7
  import huggingface_hub
8
  import sherpa
 
9
 
10
 
11
  class EnumDecodingMethod(Enum):
@@ -13,13 +14,6 @@ class EnumDecodingMethod(Enum):
13
  modified_beam_search = "modified_beam_search"
14
 
15
 
16
- class EnumRecognizerType(Enum):
17
- sherpa_offline_recognizer = "sherpa.OfflineRecognizer"
18
- sherpa_online_recognizer = "sherpa.OnlineRecognizer"
19
- sherpa_onnx_offline_recognizer = "sherpa_onnx.OfflineRecognizer"
20
- sherpa_onnx_online_recognizer = "sherpa_onnx.OnlineRecognizer"
21
-
22
-
23
  model_map = {
24
  "Chinese": [
25
  {
@@ -27,7 +21,14 @@ model_map = {
27
  "nn_model_file": "final.zip",
28
  "tokens_file": "units.txt",
29
  "sub_folder": ".",
30
- "recognizer_type": EnumRecognizerType.sherpa_offline_recognizer.value,
 
 
 
 
 
 
 
31
  }
32
  ]
33
  }
@@ -83,12 +84,31 @@ def load_sherpa_offline_recognizer(nn_model_file: str,
83
  return recognizer
84
 
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  def load_recognizer(repo_id: str,
87
  nn_model_file: str,
88
  tokens_file: str,
89
  sub_folder: str,
90
  local_model_dir: str,
91
- recognizer_type: str,
92
  decoding_method: str = "greedy_search",
93
  num_active_paths: int = 4,
94
  ):
@@ -101,15 +121,21 @@ def load_recognizer(repo_id: str,
101
  local_model_dir=local_model_dir,
102
  )
103
 
104
- if recognizer_type == EnumRecognizerType.sherpa_offline_recognizer.value:
105
  recognizer = load_sherpa_offline_recognizer(
106
  nn_model_file=nn_model_file,
107
  tokens_file=tokens_file,
108
  decoding_method=decoding_method,
109
  num_active_paths=num_active_paths,
110
  )
 
 
 
 
 
 
111
  else:
112
- raise NotImplementedError("recognizer_type not support: {}".format(recognizer_type))
113
  return recognizer
114
 
115
 
 
6
 
7
  import huggingface_hub
8
  import sherpa
9
+ import sherpa_onnx
10
 
11
 
12
  class EnumDecodingMethod(Enum):
 
14
  modified_beam_search = "modified_beam_search"
15
 
16
 
 
 
 
 
 
 
 
17
  model_map = {
18
  "Chinese": [
19
  {
 
21
  "nn_model_file": "final.zip",
22
  "tokens_file": "units.txt",
23
  "sub_folder": ".",
24
+ "loader": "load_sherpa_offline_recognizer",
25
+ },
26
+ {
27
+ "repo_id": "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
28
+ "nn_model_file": "model.int8.onnx",
29
+ "tokens_file": "tokens.txt",
30
+ "sub_folder": ".",
31
+ "loader": "load_sherpa_offline_recognizer_from_paraformer",
32
  }
33
  ]
34
  }
 
84
  return recognizer
85
 
86
 
87
+ def load_sherpa_offline_recognizer_from_paraformer(nn_model_file: str,
88
+ tokens_file: str,
89
+ sample_rate: int = 16000,
90
+ decoding_method: str = "greedy_search",
91
+ feature_dim: int = 80,
92
+ num_threads: int = 2,
93
+ ):
94
+ recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
95
+ paraformer=nn_model_file,
96
+ tokens=tokens_file,
97
+ num_threads=num_threads,
98
+ sample_rate=sample_rate,
99
+ feature_dim=feature_dim,
100
+ decoding_method=decoding_method,
101
+ debug=False,
102
+ )
103
+ return recognizer
104
+
105
+
106
  def load_recognizer(repo_id: str,
107
  nn_model_file: str,
108
  tokens_file: str,
109
  sub_folder: str,
110
  local_model_dir: str,
111
+ loader: str,
112
  decoding_method: str = "greedy_search",
113
  num_active_paths: int = 4,
114
  ):
 
121
  local_model_dir=local_model_dir,
122
  )
123
 
124
+ if loader == "load_sherpa_offline_recognizer":
125
  recognizer = load_sherpa_offline_recognizer(
126
  nn_model_file=nn_model_file,
127
  tokens_file=tokens_file,
128
  decoding_method=decoding_method,
129
  num_active_paths=num_active_paths,
130
  )
131
+ elif loader == "load_sherpa_offline_recognizer_from_paraformer":
132
+ recognizer = load_sherpa_offline_recognizer_from_paraformer(
133
+ nn_model_file=nn_model_file,
134
+ tokens_file=tokens_file,
135
+ decoding_method=decoding_method,
136
+ )
137
  else:
138
+ raise NotImplementedError("loader not support: {}".format(loader))
139
  return recognizer
140
 
141