csukuangfj
committed on
Commit
•
21fcf42
1
Parent(s):
8ed6ca2
add punctuations
Browse files
app.py
CHANGED
@@ -26,7 +26,7 @@ from pathlib import Path
|
|
26 |
import gradio as gr
|
27 |
|
28 |
from decode import decode
|
29 |
-
from model import get_pretrained_model, get_vad, language_to_models
|
30 |
|
31 |
title = "# Next-gen Kaldi: Generate subtitles for videos"
|
32 |
|
@@ -89,6 +89,7 @@ def show_file_info(in_filename: str):
|
|
89 |
def process_uploaded_video_file(
|
90 |
language: str,
|
91 |
repo_id: str,
|
|
|
92 |
in_filename: str,
|
93 |
):
|
94 |
if in_filename is None or in_filename == "":
|
@@ -105,13 +106,14 @@ def process_uploaded_video_file(
|
|
105 |
|
106 |
logging.info(f"Processing uploaded file: {in_filename}")
|
107 |
|
108 |
-
ans = process(language, repo_id, in_filename)
|
109 |
return (in_filename, ans[0]), ans[0], ans[1], ans[2]
|
110 |
|
111 |
|
112 |
def process_uploaded_audio_file(
|
113 |
language: str,
|
114 |
repo_id: str,
|
|
|
115 |
in_filename: str,
|
116 |
):
|
117 |
if in_filename is None or in_filename == "":
|
@@ -131,11 +133,15 @@ def process_uploaded_audio_file(
|
|
131 |
return process(language, repo_id, in_filename)
|
132 |
|
133 |
|
134 |
-
def process(language: str, repo_id: str, in_filename: str):
|
135 |
recognizer = get_pretrained_model(repo_id)
|
136 |
vad = get_vad()
|
|
|
|
|
|
|
|
|
137 |
|
138 |
-
result = decode(recognizer, vad, in_filename)
|
139 |
logging.info(result)
|
140 |
|
141 |
srt_filename = Path(in_filename).with_suffix(".srt")
|
@@ -176,6 +182,11 @@ with demo:
|
|
176 |
inputs=language_radio,
|
177 |
outputs=model_dropdown,
|
178 |
)
|
|
|
|
|
|
|
|
|
|
|
179 |
|
180 |
with gr.Tabs():
|
181 |
with gr.TabItem("Upload video from disk"):
|
@@ -218,6 +229,7 @@ with demo:
|
|
218 |
inputs=[
|
219 |
language_radio,
|
220 |
model_dropdown,
|
|
|
221 |
uploaded_video_file,
|
222 |
],
|
223 |
outputs=[
|
@@ -233,6 +245,7 @@ with demo:
|
|
233 |
inputs=[
|
234 |
language_radio,
|
235 |
model_dropdown,
|
|
|
236 |
uploaded_audio_file,
|
237 |
],
|
238 |
outputs=[
|
|
|
26 |
import gradio as gr
|
27 |
|
28 |
from decode import decode
|
29 |
+
from model import get_pretrained_model, get_vad, language_to_models, get_punct_model
|
30 |
|
31 |
title = "# Next-gen Kaldi: Generate subtitles for videos"
|
32 |
|
|
|
89 |
def process_uploaded_video_file(
|
90 |
language: str,
|
91 |
repo_id: str,
|
92 |
+
add_punctuation: str,
|
93 |
in_filename: str,
|
94 |
):
|
95 |
if in_filename is None or in_filename == "":
|
|
|
106 |
|
107 |
logging.info(f"Processing uploaded file: {in_filename}")
|
108 |
|
109 |
+
ans = process(language, repo_id, add_punctuation, in_filename)
|
110 |
return (in_filename, ans[0]), ans[0], ans[1], ans[2]
|
111 |
|
112 |
|
113 |
def process_uploaded_audio_file(
|
114 |
language: str,
|
115 |
repo_id: str,
|
116 |
+
add_punctuation: str,
|
117 |
in_filename: str,
|
118 |
):
|
119 |
if in_filename is None or in_filename == "":
|
|
|
133 |
return process(language, repo_id, in_filename)
|
134 |
|
135 |
|
136 |
+
def process(language: str, repo_id: str, add_punctuation: str, in_filename: str):
|
137 |
recognizer = get_pretrained_model(repo_id)
|
138 |
vad = get_vad()
|
139 |
+
if add_punctuation == "Yes":
|
140 |
+
punct = get_punct_model()
|
141 |
+
else:
|
142 |
+
punct = None
|
143 |
|
144 |
+
result = decode(recognizer, vad, punct, in_filename)
|
145 |
logging.info(result)
|
146 |
|
147 |
srt_filename = Path(in_filename).with_suffix(".srt")
|
|
|
182 |
inputs=language_radio,
|
183 |
outputs=model_dropdown,
|
184 |
)
|
185 |
+
punct_radio = gr.Radio(
|
186 |
+
label="Whether to add punctuation",
|
187 |
+
choices=["Yes", "No"],
|
188 |
+
value="Yes",
|
189 |
+
)
|
190 |
|
191 |
with gr.Tabs():
|
192 |
with gr.TabItem("Upload video from disk"):
|
|
|
229 |
inputs=[
|
230 |
language_radio,
|
231 |
model_dropdown,
|
232 |
+
punct_radio,
|
233 |
uploaded_video_file,
|
234 |
],
|
235 |
outputs=[
|
|
|
245 |
inputs=[
|
246 |
language_radio,
|
247 |
model_dropdown,
|
248 |
+
punct_radio,
|
249 |
uploaded_audio_file,
|
250 |
],
|
251 |
outputs=[
|
decode.py
CHANGED
@@ -48,6 +48,7 @@ class Segment:
|
|
48 |
def decode(
|
49 |
recognizer: sherpa_onnx.OfflineRecognizer,
|
50 |
vad: sherpa_onnx.VoiceActivityDetector,
|
|
|
51 |
filename: str,
|
52 |
) -> str:
|
53 |
ffmpeg_cmd = [
|
@@ -114,6 +115,8 @@ def decode(
|
|
114 |
|
115 |
for seg, stream in zip(segments, streams):
|
116 |
seg.text = stream.result.text.strip()
|
|
|
|
|
117 |
segment_list.append(seg)
|
118 |
|
119 |
return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1))
|
|
|
48 |
def decode(
|
49 |
recognizer: sherpa_onnx.OfflineRecognizer,
|
50 |
vad: sherpa_onnx.VoiceActivityDetector,
|
51 |
+
punct: Optional[sherpa_onnx.OfflinePunctuation],
|
52 |
filename: str,
|
53 |
) -> str:
|
54 |
ffmpeg_cmd = [
|
|
|
115 |
|
116 |
for seg, stream in zip(segments, streams):
|
117 |
seg.text = stream.result.text.strip()
|
118 |
+
if punct is not None:
|
119 |
+
seg.text = punct.add_punctuation(seg.text)
|
120 |
segment_list.append(seg)
|
121 |
|
122 |
return "\n\n".join(f"{i}\n{seg}" for i, seg in enumerate(segment_list, 1))
|
model.py
CHANGED
@@ -168,6 +168,21 @@ def _get_russian_pre_trained_model(repo_id: str) -> sherpa_onnx.OfflineRecognize
|
|
168 |
return recognizer
|
169 |
|
170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
171 |
@lru_cache(maxsize=2)
|
172 |
def get_vad() -> sherpa_onnx.VoiceActivityDetector:
|
173 |
vad_model = _get_nn_model_filename(
|
|
|
168 |
return recognizer
|
169 |
|
170 |
|
171 |
+
@lru_cache(maxsize=2)
|
172 |
+
def get_punct_model() -> sherpa_onnx.OfflinePunctuation:
|
173 |
+
model = _get_nn_model_filename(
|
174 |
+
repo_id="csukuangfj/sherpa-onnx-punct-ct-transformer-zh-en-vocab272727-2024-04-12",
|
175 |
+
filename="model.onnx",
|
176 |
+
subfolder=".",
|
177 |
+
)
|
178 |
+
config = sherpa_onnx.OfflinePunctuationConfig(
|
179 |
+
model=sherpa_onnx.OfflinePunctuationModelConfig(ct_transformer=model),
|
180 |
+
)
|
181 |
+
|
182 |
+
punct = sherpa_onnx.OfflinePunctuation(config)
|
183 |
+
return punct
|
184 |
+
|
185 |
+
|
186 |
@lru_cache(maxsize=2)
|
187 |
def get_vad() -> sherpa_onnx.VoiceActivityDetector:
|
188 |
vad_model = _get_nn_model_filename(
|