csukuangfj committed
Commit 21fcf42
1 Parent(s): 8ed6ca2

add punctuations

Files changed (3):
  1. app.py +18 -5
  2. decode.py +3 -0
  3. model.py +15 -0
app.py CHANGED

```diff
@@ -26,7 +26,7 @@ from pathlib import Path
 import gradio as gr
 
 from decode import decode
-from model import get_pretrained_model, get_vad, language_to_models
+from model import get_pretrained_model, get_vad, language_to_models, get_punct_model
 
 title = "# Next-gen Kaldi: Generate subtitles for videos"
 
@@ -89,6 +89,7 @@ def show_file_info(in_filename: str):
 def process_uploaded_video_file(
     language: str,
     repo_id: str,
+    add_punctuation: str,
     in_filename: str,
 ):
     if in_filename is None or in_filename == "":
@@ -105,13 +106,14 @@ def process_uploaded_video_file(
 
     logging.info(f"Processing uploaded file: {in_filename}")
 
-    ans = process(language, repo_id, in_filename)
+    ans = process(language, repo_id, add_punctuation, in_filename)
     return (in_filename, ans[0]), ans[0], ans[1], ans[2]
 
 
 def process_uploaded_audio_file(
     language: str,
     repo_id: str,
+    add_punctuation: str,
     in_filename: str,
 ):
     if in_filename is None or in_filename == "":
@@ -131,11 +133,15 @@ def process_uploaded_audio_file(
-    return process(language, repo_id, in_filename)
+    return process(language, repo_id, add_punctuation, in_filename)
 
 
-def process(language: str, repo_id: str, in_filename: str):
+def process(language: str, repo_id: str, add_punctuation: str, in_filename: str):
     recognizer = get_pretrained_model(repo_id)
     vad = get_vad()
+    if add_punctuation == "Yes":
+        punct = get_punct_model()
+    else:
+        punct = None
 
-    result = decode(recognizer, vad, in_filename)
+    result = decode(recognizer, vad, punct, in_filename)
     logging.info(result)
 
     srt_filename = Path(in_filename).with_suffix(".srt")
@@ -176,6 +182,11 @@ with demo:
         inputs=language_radio,
         outputs=model_dropdown,
     )
+    punct_radio = gr.Radio(
+        label="Whether to add punctuation",
+        choices=["Yes", "No"],
+        value="Yes",
+    )
 
     with gr.Tabs():
         with gr.TabItem("Upload video from disk"):
@@ -218,6 +229,7 @@ with demo:
             inputs=[
                 language_radio,
                 model_dropdown,
+                punct_radio,
                 uploaded_video_file,
             ],
             outputs=[
@@ -233,6 +245,7 @@ with demo:
             inputs=[
                 language_radio,
                 model_dropdown,
+                punct_radio,
                 uploaded_audio_file,
             ],
             outputs=[
```
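A note on the wiring above: `gr.Radio` passes the selected choice to the callback as the literal string `"Yes"` or `"No"`, which is why `process()` compares against `"Yes"` rather than a boolean. A minimal self-contained sketch of the same pattern (the `echo_with_punct` callback and its labels are illustrative, not part of this commit):

```python
import gradio as gr


def echo_with_punct(add_punctuation: str, text: str) -> str:
    # Mirrors process(): the radio delivers "Yes" or "No" as a string,
    # so the comparison is against "Yes", not a bool.
    if add_punctuation == "Yes":
        return text + "."  # stand-in for punct.add_punctuation(text)
    return text


with gr.Blocks() as demo:
    punct_radio = gr.Radio(
        label="Whether to add punctuation",
        choices=["Yes", "No"],
        value="Yes",
    )
    inp = gr.Textbox(label="Input")
    out = gr.Textbox(label="Output")
    gr.Button("Run").click(echo_with_punct, inputs=[punct_radio, inp], outputs=out)

demo.launch()
```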
decode.py CHANGED

```diff
@@ -48,6 +48,7 @@ class Segment:
 def decode(
     recognizer: sherpa_onnx.OfflineRecognizer,
     vad: sherpa_onnx.VoiceActivityDetector,
+    punct: Optional[sherpa_onnx.OfflinePunctuation],
     filename: str,
 ) -> str:
     ffmpeg_cmd = [
@@ -114,6 +115,8 @@ def decode(
 
     for seg, stream in zip(segments, streams):
         seg.text = stream.result.text.strip()
+        if punct is not None:
+            seg.text = punct.add_punctuation(seg.text)
         segment_list.append(seg)
 
     return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1))
```
model.py CHANGED

```diff
@@ -168,6 +168,21 @@ def _get_russian_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognizer:
     return recognizer
 
 
+@lru_cache(maxsize=2)
+def get_punct_model() -> sherpa_onnx.OfflinePunctuation:
+    model = _get_nn_model_filename(
+        repo_id="csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12",
+        filename="model.onnx",
+        subfolder=".",
+    )
+    config = sherpa_onnx.OfflinePunctuationConfig(
+        model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model),
+    )
+
+    punct = sherpa_onnx.OfflinePunctuation(config)
+    return punct
+
+
 @lru_cache(maxsize=2)
 def get_vad() -> sherpa_onnx.VoiceActivityDetector:
     vad_model = _get_nn_model_filename(
```
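Like `get_vad()`, the new `get_punct_model()` is wrapped in `@lru_cache`, so the ONNX model is fetched and constructed at most once per process; later calls return the cached instance. A quick check, assuming `model.py` is importable (e.g., inside the Space):

```python
from model import get_punct_model

p1 = get_punct_model()
p2 = get_punct_model()
assert p1 is p2  # lru_cache hands back the same OfflinePunctuation object
```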