csukuangfj commited on
Commit
09ae8c3
1 Parent(s): 1420134

add punctuations

Browse files
Files changed (3)
  1. app.py +28 -1
  2. model.py +15 -0
  3. requirements.txt +1 -1
app.py CHANGED
@@ -32,7 +32,13 @@ import torch
32
  import torchaudio
33
 
34
  from examples import examples
35
- from model import decode, get_pretrained_model, language_to_models, sample_rate
 
 
 
 
 
 
36
 
37
  languages = list(language_to_models.keys())
38
 
@@ -65,6 +71,7 @@ def process_url(
65
  repo_id: str,
66
  decoding_method: str,
67
  num_active_paths: int,
 
68
  url: str,
69
  ):
70
  logging.info(f"Processing URL: {url}")
@@ -78,6 +85,7 @@ def process_url(
78
  repo_id=repo_id,
79
  decoding_method=decoding_method,
80
  num_active_paths=num_active_paths,
 
81
  )
82
  except Exception as e:
83
  logging.info(str(e))
@@ -89,6 +97,7 @@ def process_uploaded_file(
89
  repo_id: str,
90
  decoding_method: str,
91
  num_active_paths: int,
 
92
  in_filename: str,
93
  ):
94
  if in_filename is None or in_filename == "":
@@ -106,6 +115,7 @@ def process_uploaded_file(
106
  repo_id=repo_id,
107
  decoding_method=decoding_method,
108
  num_active_paths=num_active_paths,
 
109
  )
110
  except Exception as e:
111
  logging.info(str(e))
@@ -117,6 +127,7 @@ def process_microphone(
117
  repo_id: str,
118
  decoding_method: str,
119
  num_active_paths: int,
 
120
  in_filename: str,
121
  ):
122
  if in_filename is None or in_filename == "":
@@ -135,6 +146,7 @@ def process_microphone(
135
  repo_id=repo_id,
136
  decoding_method=decoding_method,
137
  num_active_paths=num_active_paths,
 
138
  )
139
  except Exception as e:
140
  logging.info(str(e))
@@ -147,6 +159,7 @@ def process(
147
  repo_id: str,
148
  decoding_method: str,
149
  num_active_paths: int,
 
150
  in_filename: str,
151
  ):
152
  logging.info(f"language: {language}")
@@ -170,6 +183,9 @@ def process(
170
  )
171
 
172
  text = decode(recognizer, filename)
 
 
 
173
 
174
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
175
  end = time.time()
@@ -277,6 +293,12 @@ with demo:
277
  label="Number of active paths for modified_beam_search",
278
  )
279
 
 
 
 
 
 
 
280
  with gr.Tabs():
281
  with gr.TabItem("Upload from disk"):
282
  uploaded_file = gr.Audio(
@@ -295,6 +317,7 @@ with demo:
295
  model_dropdown,
296
  decoding_method_radio,
297
  num_active_paths_slider,
 
298
  uploaded_file,
299
  ],
300
  outputs=[uploaded_output, uploaded_html_info],
@@ -319,6 +342,7 @@ with demo:
319
  model_dropdown,
320
  decoding_method_radio,
321
  num_active_paths_slider,
 
322
  microphone,
323
  ],
324
  outputs=[recorded_output, recorded_html_info],
@@ -344,6 +368,7 @@ with demo:
344
  model_dropdown,
345
  decoding_method_radio,
346
  num_active_paths_slider,
 
347
  uploaded_file,
348
  ],
349
  outputs=[uploaded_output, uploaded_html_info],
@@ -356,6 +381,7 @@ with demo:
356
  model_dropdown,
357
  decoding_method_radio,
358
  num_active_paths_slider,
 
359
  microphone,
360
  ],
361
  outputs=[recorded_output, recorded_html_info],
@@ -368,6 +394,7 @@ with demo:
368
  model_dropdown,
369
  decoding_method_radio,
370
  num_active_paths_slider,
 
371
  url_textbox,
372
  ],
373
  outputs=[url_output, url_html_info],
32
  import torchaudio
33
 
34
  from examples import examples
35
+ from model import (
36
+ decode,
37
+ get_pretrained_model,
38
+ get_punct_model,
39
+ language_to_models,
40
+ sample_rate,
41
+ )
42
 
43
  languages = list(language_to_models.keys())
44
 
71
  repo_id: str,
72
  decoding_method: str,
73
  num_active_paths: int,
74
+ add_punct: str,
75
  url: str,
76
  ):
77
  logging.info(f"Processing URL: {url}")
85
  repo_id=repo_id,
86
  decoding_method=decoding_method,
87
  num_active_paths=num_active_paths,
88
+ add_punct=add_punct,
89
  )
90
  except Exception as e:
91
  logging.info(str(e))
97
  repo_id: str,
98
  decoding_method: str,
99
  num_active_paths: int,
100
+ add_punct: str,
101
  in_filename: str,
102
  ):
103
  if in_filename is None or in_filename == "":
115
  repo_id=repo_id,
116
  decoding_method=decoding_method,
117
  num_active_paths=num_active_paths,
118
+ add_punct=add_punct,
119
  )
120
  except Exception as e:
121
  logging.info(str(e))
127
  repo_id: str,
128
  decoding_method: str,
129
  num_active_paths: int,
130
+ add_punct: str,
131
  in_filename: str,
132
  ):
133
  if in_filename is None or in_filename == "":
146
  repo_id=repo_id,
147
  decoding_method=decoding_method,
148
  num_active_paths=num_active_paths,
149
+ add_punct=add_punct,
150
  )
151
  except Exception as e:
152
  logging.info(str(e))
159
  repo_id: str,
160
  decoding_method: str,
161
  num_active_paths: int,
162
+ add_punct: str,
163
  in_filename: str,
164
  ):
165
  logging.info(f"language: {language}")
183
  )
184
 
185
  text = decode(recognizer, filename)
186
+ if add_punct == "Yes":
187
+ punct = get_punct_model()
188
+ text = punct.add_punctuation(text)
189
 
190
  date_time = now.strftime("%Y-%m-%d %H:%M:%S.%f")
191
  end = time.time()
293
  label="Number of active paths for modified_beam_search",
294
  )
295
 
296
+ punct_radio = gr.Radio(
297
+ label="Whether to add punctuation (Only for Chinese and English)",
298
+ choices=["Yes", "No"],
299
+ value="Yes",
300
+ )
301
+
302
  with gr.Tabs():
303
  with gr.TabItem("Upload from disk"):
304
  uploaded_file = gr.Audio(
317
  model_dropdown,
318
  decoding_method_radio,
319
  num_active_paths_slider,
320
+ punct_radio,
321
  uploaded_file,
322
  ],
323
  outputs=[uploaded_output, uploaded_html_info],
342
  model_dropdown,
343
  decoding_method_radio,
344
  num_active_paths_slider,
345
+ punct_radio,
346
  microphone,
347
  ],
348
  outputs=[recorded_output, recorded_html_info],
368
  model_dropdown,
369
  decoding_method_radio,
370
  num_active_paths_slider,
371
+ punct_radio,
372
  uploaded_file,
373
  ],
374
  outputs=[uploaded_output, uploaded_html_info],
381
  model_dropdown,
382
  decoding_method_radio,
383
  num_active_paths_slider,
384
+ punct_radio,
385
  microphone,
386
  ],
387
  outputs=[recorded_output, recorded_html_info],
394
  model_dropdown,
395
  decoding_method_radio,
396
  num_active_paths_slider,
397
+ punct_radio,
398
  url_textbox,
399
  ],
400
  outputs=[url_output, url_html_info],
model.py CHANGED
@@ -1182,6 +1182,21 @@ def _get_aishell_pre_trained_model(
1182
  return recognizer
1183
 
1184
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1185
  def _get_multi_zh_hans_pre_trained_model(
1186
  repo_id: str,
1187
  decoding_method: str,
1182
  return recognizer
1183
 
1184
 
1185
+ @lru_cache(maxsize=2)
1186
+ def get_punct_model() -> sherpa_onnx.OfflinePunctuation:
1187
+ model = _get_nn_model_filename(
1188
+ repo_id="csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12",
1189
+ filename="model.onnx",
1190
+ subfolder=".",
1191
+ )
1192
+ config = sherpa_onnx.OfflinePunctuationConfig(
1193
+ model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model),
1194
+ )
1195
+
1196
+ punct = sherpa_onnx.OfflinePunctuation(config)
1197
+ return punct
1198
+
1199
+
1200
  def _get_multi_zh_hans_pre_trained_model(
1201
  repo_id: str,
1202
  decoding_method: str,
requirements.txt CHANGED
@@ -9,4 +9,4 @@ sentencepiece>=0.1.96
9
  numpy
10
 
11
  huggingface_hub
12
- sherpa-onnx
9
  numpy
10
 
11
  huggingface_hub
12
+ sherpa-onnx>=1.9.19