csukuangfj commited on
Commit
06b4245
1 Parent(s): 994c238

add chinese models

Browse files
Files changed (1) hide show
  1. model.py +47 -1
model.py CHANGED
@@ -192,7 +192,9 @@ def get_vad() -> sherpa_onnx.VoiceActivityDetector:
192
 
193
  @lru_cache(maxsize=10)
194
  def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
195
- if repo_id in english_models:
 
 
196
  return english_models[repo_id](repo_id)
197
  elif repo_id in chinese_english_mixed_models:
198
  return chinese_english_mixed_models[repo_id](repo_id)
@@ -202,6 +204,49 @@ def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
202
  raise ValueError(f"Unsupported repo_id: {repo_id}")
203
 
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  english_models = {
206
  "whisper-tiny.en": _get_whisper_model,
207
  "whisper-base.en": _get_whisper_model,
@@ -218,6 +263,7 @@ russian_models = {
218
  }
219
 
220
  language_to_models = {
 
221
  "English": list(english_models.keys()),
222
  "Chinese+English": list(chinese_english_mixed_models.keys()),
223
  "Russian": list(russian_models.keys()),
 
192
 
193
  @lru_cache(maxsize=10)
194
  def get_pretrained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
195
+ if repo_id in chinese_models:
196
+ return chinese_models[repo_id](repo_id)
197
+ elif repo_id in english_models:
198
  return english_models[repo_id](repo_id)
199
  elif repo_id in chinese_english_mixed_models:
200
  return chinese_english_mixed_models[repo_id](repo_id)
 
204
  raise ValueError(f"Unsupported repo_id: {repo_id}")
205
 
206
 
207
+ def _get_wenetspeech_pre_trained_model(repo_id):
208
+ assert repo_id in (
209
+ "csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23",
210
+ ), repo_id
211
+
212
+ encoder_model = _get_nn_model_filename(
213
+ repo_id=repo_id,
214
+ filename="encoder-epoch-99-avg-1.onnx",
215
+ subfolder=".",
216
+ )
217
+
218
+ decoder_model = _get_nn_model_filename(
219
+ repo_id=repo_id,
220
+ filename="decoder-epoch-99-avg-1.onnx",
221
+ subfolder=".",
222
+ )
223
+
224
+ joiner_model = _get_nn_model_filename(
225
+ repo_id=repo_id,
226
+ filename="joiner-epoch-99-avg-1.onnx",
227
+ subfolder=".",
228
+ )
229
+
230
+ tokens = _get_token_filename(repo_id=repo_id, subfolder=".")
231
+
232
+ recognizer = sherpa_onnx.OfflineRecognizer.from_transducer(
233
+ tokens=tokens,
234
+ encoder=encoder_model,
235
+ decoder=decoder_model,
236
+ joiner=joiner_model,
237
+ num_threads=2,
238
+ sample_rate=16000,
239
+ feature_dim=80,
240
+ decoding_method="greedy_search",
241
+ )
242
+
243
+ return recognizer
244
+
245
+
246
+ chinese_models = {
247
+ "csukuangfj/sherpa-onnx-conformer-zh-stateless2-2023-05-23": _get_wenetspeech_pre_trained_model, # noqa
248
+ }
249
+
250
  english_models = {
251
  "whisper-tiny.en": _get_whisper_model,
252
  "whisper-base.en": _get_whisper_model,
 
263
  }
264
 
265
  language_to_models = {
266
+ "Chinese": list(chinese_models),
267
  "English": list(english_models.keys()),
268
  "Chinese+English": list(chinese_english_mixed_models.keys()),
269
  "Russian": list(russian_models.keys()),