csukuangfj committed on
Commit
6781708
1 Parent(s): ea07244

small fixes

Browse files
Files changed (1) hide show
  1. model.py +149 -0
model.py CHANGED
@@ -14,8 +14,157 @@
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
 
 
 
 
17
  from huggingface_hub import hf_hub_download
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  english_models = {
20
  "whisper-tiny.en": _get_whisper_model,
21
  "whisper-base.en": _get_whisper_model,
 
14
  # See the License for the specific language governing permissions and
15
  # limitations under the License.
16
 
17
+ from functools import lru_cache
18
+
19
+ import sherpa_onnx
20
  from huggingface_hub import hf_hub_download
21
 
22
+ sample_rate = 16000
23
+
24
+
25
def _get_nn_model_filename(
    repo_id: str,
    filename: str,
    subfolder: str = "exp",
) -> str:
    """Fetch a neural-network model file through ``hf_hub_download``.

    Args:
      repo_id: Repository id passed to ``hf_hub_download``.
      filename: Name of the model file to fetch.
      subfolder: Folder inside the repository that contains ``filename``.

    Returns:
      The value returned by ``hf_hub_download`` for the requested file.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
36
+
37
+
38
def _get_bpe_model_filename(
    repo_id: str,
    filename: str = "bpe.model",
    subfolder: str = "data/lang_bpe_500",
) -> str:
    """Fetch a BPE model file through ``hf_hub_download``.

    Args:
      repo_id: Repository id passed to ``hf_hub_download``.
      filename: Name of the BPE model file to fetch.
      subfolder: Folder inside the repository that contains ``filename``.

    Returns:
      The value returned by ``hf_hub_download`` for the requested file.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
49
+
50
+
51
def _get_token_filename(
    repo_id: str,
    filename: str = "tokens.txt",
    subfolder: str = "data/lang_char",
) -> str:
    """Fetch a token table file through ``hf_hub_download``.

    Args:
      repo_id: Repository id passed to ``hf_hub_download``.
      filename: Name of the token file to fetch.
      subfolder: Folder inside the repository that contains ``filename``.

    Returns:
      The value returned by ``hf_hub_download`` for the requested file.
    """
    return hf_hub_download(repo_id=repo_id, filename=filename, subfolder=subfolder)
62
+
63
+
64
@lru_cache(maxsize=10)
def _get_whisper_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Create an offline Whisper recognizer for the given short model id.

    Results are memoized per ``repo_id`` (up to 10 entries), so repeated
    requests for the same model reuse the already-built recognizer.

    Args:
      repo_id: Short id such as ``"whisper-tiny.en"``. The piece between the
        first and second ``-`` selects the Whisper variant.

    Returns:
      The recognizer produced by ``OfflineRecognizer.from_whisper``.
    """
    variant = repo_id.split("-")[1]
    assert variant in ("tiny.en", "base.en", "small.en", "medium.en"), repo_id

    hub_repo = "csukuangfj/sherpa-onnx-whisper-" + variant

    # Every file sits at the repository root (subfolder ".") and carries the
    # variant name as a prefix.
    encoder_path = _get_nn_model_filename(
        repo_id=hub_repo,
        filename=f"{variant}-encoder.int8.ort",
        subfolder=".",
    )
    decoder_path = _get_nn_model_filename(
        repo_id=hub_repo,
        filename=f"{variant}-decoder.int8.ort",
        subfolder=".",
    )
    tokens_path = _get_token_filename(
        repo_id=hub_repo,
        subfolder=".",
        filename=f"{variant}-tokens.txt",
    )

    return sherpa_onnx.OfflineRecognizer.from_whisper(
        encoder=encoder_path,
        decoder=decoder_path,
        tokens=tokens_path,
        num_threads=2,
    )
93
+
94
+
95
@lru_cache(maxsize=10)
def _get_paraformer_zh_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Create an offline Paraformer recognizer from a supported repository.

    Results are memoized per ``repo_id`` (up to 10 entries).

    Args:
      repo_id: Repository id; must be one of the ids asserted below.

    Returns:
      The recognizer produced by ``OfflineRecognizer.from_paraformer``,
      configured for greedy search with 80-dim features at the module-level
      ``sample_rate``.
    """
    assert repo_id in [
        "csukuangfj/sherpa-onnx-paraformer-zh-2023-03-28",
    ], repo_id

    # Both the model and the token table are stored at the repository root.
    model_path = _get_nn_model_filename(
        repo_id=repo_id,
        filename="model.int8.onnx",
        subfolder=".",
    )
    tokens_path = _get_token_filename(repo_id=repo_id, subfolder=".")

    return sherpa_onnx.OfflineRecognizer.from_paraformer(
        paraformer=model_path,
        tokens=tokens_path,
        num_threads=2,
        sample_rate=sample_rate,
        feature_dim=80,
        decoding_method="greedy_search",
        debug=False,
    )
120
+
121
+
122
@lru_cache(maxsize=10)
def _get_russian_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
    """Create an offline transducer recognizer for a supported Russian model.

    Results are memoized per ``repo_id`` (up to 10 entries).

    Args:
      repo_id: Repository id; must be a key of ``model_dirs`` below.

    Returns:
      The recognizer produced by ``OfflineRecognizer.from_transducer``,
      configured for greedy search with 80-dim features at the module-level
      ``sample_rate``.

    Raises:
      AssertionError: If ``repo_id`` is not supported. When assertions are
        disabled (``python -O``) the table lookup raises ``KeyError`` instead.
    """
    # Single source of truth: supported repository -> folder holding the
    # encoder/decoder/joiner ONNX files. Keeping the check and the lookup on
    # the same table means they cannot drift apart and ``model_dir`` can never
    # be left unbound.
    model_dirs = {
        "alphacep/vosk-model-ru": "am-onnx",
        "alphacep/vosk-model-small-ru": "am",
    }
    assert repo_id in model_dirs, repo_id
    model_dir = model_dirs[repo_id]

    encoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="encoder.onnx",
        subfolder=model_dir,
    )
    decoder_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="decoder.onnx",
        subfolder=model_dir,
    )
    joiner_model = _get_nn_model_filename(
        repo_id=repo_id,
        filename="joiner.onnx",
        subfolder=model_dir,
    )

    tokens = _get_token_filename(repo_id=repo_id, subfolder="lang")

    return sherpa_onnx.OfflineRecognizer.from_transducer(
        tokens=tokens,
        encoder=encoder_model,
        decoder=decoder_model,
        joiner=joiner_model,
        num_threads=2,
        # Use the shared module constant (16000) rather than a second
        # hard-coded copy, matching the paraformer loader.
        sample_rate=sample_rate,
        feature_dim=80,
        decoding_method="greedy_search",
    )
166
+
167
+
168
  english_models = {
169
  "whisper-tiny.en": _get_whisper_model,
170
  "whisper-base.en": _get_whisper_model,