bofenghuang committed
Commit
8ecd0fd
1 Parent(s): 920af5f
app.py CHANGED
@@ -1 +1 @@
- run_demo_low_api_openai.py
+ run_demo_ct2.py
requirements.txt CHANGED
@@ -1,5 +1,6 @@
- git+https://github.com/huggingface/transformers
+ git+https://github.com/huggingface/transformers.git
  git+https://github.com/openai/whisper.git
+ git+https://github.com/guillaumekln/faster-whisper.git
  nltk
  pandas
  psutil
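
The new default model points to a `-ct2` repository, i.e. a checkpoint already converted to CTranslate2 format. A minimal sketch of how such a conversion is typically produced (this step is not part of the commit; the output directory name is illustrative):

    import ctranslate2

    # convert the fine-tuned Hugging Face checkpoint to CTranslate2 format,
    # quantizing weights to float16 for GPU inference (equivalent to the
    # ct2-transformers-converter command-line tool)
    converter = ctranslate2.converters.TransformersConverter("bofenghuang/whisper-large-v2-cv11-german")
    converter.convert("whisper-large-v2-cv11-german-ct2", quantization="float16")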
run_demo_ct2.py ADDED
@@ -0,0 +1,324 @@
+ #!/usr/bin/env python
+ # coding=utf-8
+ # Copyright 2022 Bofeng Huang
+
+ import datetime
+ import logging
+ import os
+ import re
+ import warnings
+
+ import gradio as gr
+ import pandas as pd
+ import psutil
+ import pytube as pt
+ import torch
+ # import whisper
+ from faster_whisper import WhisperModel
+ from huggingface_hub import hf_hub_download, snapshot_download
+ from transformers.utils.logging import disable_progress_bar
+
+ import nltk
+ nltk.download("punkt")
+
+ from nltk.tokenize import sent_tokenize
+
+ warnings.filterwarnings("ignore")
+ disable_progress_bar()
+
+ # DEFAULT_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-german"
+ DEFAULT_MODEL_NAME = "bofenghuang/whisper-large-v2-cv11-german-ct2"
+ # CHECKPOINT_FILENAME = "checkpoint_openai.pt"
+
+ GEN_KWARGS = {
+     "task": "transcribe",
+     "language": "de",
+     # "without_timestamps": True,
+     # decode options
+     # "beam_size": 5,
+     # "patience": 2,
+     # disable fallback
+     # "compression_ratio_threshold": None,
+     # "logprob_threshold": None,
+     # vad threshold
+     # "no_speech_threshold": None,
+ }
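+ # these kwargs are forwarded to WhisperModel.transcribe(); note that
+ # faster-whisper spells some openai-whisper options differently (for example
+ # log_prob_threshold rather than logprob_threshold), so the commented-out
+ # options above would need renaming before being re-enabled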
+
+ logging.basicConfig(
+     format="%(asctime)s [%(levelname)s] [%(name)s] %(message)s",
+     datefmt="%Y-%m-%dT%H:%M:%SZ",
+ )
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.DEBUG)
+
+ # device = 0 if torch.cuda.is_available() else "cpu"
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ logger.info(f"Model will be loaded on device `{device}`")
+
+ cached_models = {}
+
+
+ def format_timestamp(seconds):
+     return str(datetime.timedelta(seconds=round(seconds)))
+
+
+ def _return_yt_html_embed(yt_url):
+     video_id = yt_url.split("?v=")[-1]
+     HTML_str = (
+         f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>' " </center>"
+     )
+     return HTML_str
+
+
+ def download_audio_from_youtube(yt_url, downloaded_filename="audio.wav"):
+     yt = pt.YouTube(yt_url)
+     stream = yt.streams.filter(only_audio=True)[0]
+     # stream.download(filename="audio.mp3")
+     stream.download(filename=downloaded_filename)
+     return downloaded_filename
+
+
+ def download_video_from_youtube(yt_url, downloaded_filename="video.mp4"):
+     yt = pt.YouTube(yt_url)
+     stream = yt.streams.filter(progressive=True, file_extension="mp4").order_by("resolution").desc().first()
+     stream.download(filename=downloaded_filename)
+     logger.info(f"Downloaded YouTube video from {yt_url}")
+     return downloaded_filename
+
+ def _print_memory_info():
+     memory = psutil.virtual_memory()
+     logger.info(
+         f"Memory info - Free: {memory.available / (1024 ** 3):.2f} GB, used: {memory.percent}%, total: {memory.total / (1024 ** 3):.2f} GB"
+     )
+
+
+ def _print_cuda_memory_info():
+     # torch.cuda.mem_get_info() returns (free, total) in bytes
+     free_mem, tot_mem = torch.cuda.mem_get_info()
+     logger.info(
+         f"CUDA memory info - Free: {free_mem / 1024 ** 3:.2f} GB, used: {(tot_mem - free_mem) / 1024 ** 3:.2f} GB, total: {tot_mem / 1024 ** 3:.2f} GB"
+     )
+
+
+ def print_memory_info():
+     _print_memory_info()
+     _print_cuda_memory_info()
+
+
+ def maybe_load_cached_pipeline(model_name):
+     model = cached_models.get(model_name)
+     if model is None:
+         # downloaded_model_path = hf_hub_download(repo_id=model_name, filename=CHECKPOINT_FILENAME)
+         downloaded_model_path = snapshot_download(repo_id=model_name)
+
+         # model = whisper.load_model(downloaded_model_path, device=device)
+         model = WhisperModel(downloaded_model_path, device=device, compute_type="float16")
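+         # compute_type="float16" assumes a CUDA device; on CPU, CTranslate2
+         # needs a CPU-supported type instead, e.g.:
+         #     compute_type = "float16" if device == "cuda" else "int8"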
+         logger.info(f"`{model_name}` has been loaded on device `{device}`")
+
+         print_memory_info()
+
+         cached_models[model_name] = model
+     return model
+
+
+ def infer(model, filename, with_timestamps, return_df=False):
+     if with_timestamps:
+         # model_outputs = model.transcribe(filename, **GEN_KWARGS)
+         model_outputs, _ = model.transcribe(filename, **GEN_KWARGS)
+         model_outputs = [segment._asdict() for segment in model_outputs]
+         if return_df:
+             # model_outputs_df = pd.DataFrame(model_outputs["segments"])
+             model_outputs_df = pd.DataFrame(model_outputs)
+             # print(model_outputs)
+             # print(model_outputs_df)
+             # print(model_outputs_df.info(verbose=True))
+             model_outputs_df = model_outputs_df[["start", "end", "text"]]
+             model_outputs_df["start"] = model_outputs_df["start"].map(format_timestamp)
+             model_outputs_df["end"] = model_outputs_df["end"].map(format_timestamp)
+             model_outputs_df["text"] = model_outputs_df["text"].str.strip()
+             return model_outputs_df
+         else:
+             return "\n\n".join(
+                 [
+                     f'Segment {segment["id"] + 1} from {segment["start"]:.2f}s to {segment["end"]:.2f}s:\n{segment["text"].strip()}'
+                     # for segment in model_outputs["segments"]
+                     for segment in model_outputs
+                 ]
+             )
+     else:
+         # text = model.transcribe(filename, without_timestamps=True, **GEN_KWARGS)["text"]
+         model_outputs, _ = model.transcribe(filename, without_timestamps=True, **GEN_KWARGS)
+         text = " ".join([segment.text for segment in model_outputs])
+         if return_df:
+             return pd.DataFrame({"text": sent_tokenize(text)})
+         else:
+             return text
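+ # WhisperModel.transcribe() returns a pair (segments, info); `segments` is a
+ # lazy generator that only decodes while being consumed, hence the list
+ # comprehension / join above to materialize the transcription eagerly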
+
+
+ def transcribe(microphone, file_upload, with_timestamps, model_name=DEFAULT_MODEL_NAME):
+     warn_output = ""
+     if (microphone is not None) and (file_upload is not None):
+         warn_output = (
+             "WARNING: You've uploaded an audio file and used the microphone. "
+             "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
+         )
+
+     elif (microphone is None) and (file_upload is None):
+         return "ERROR: You have to either use the microphone or upload an audio file"
+
+     file = microphone if microphone is not None else file_upload
+
+     model = maybe_load_cached_pipeline(model_name)
+     # text = model.transcribe(file, **GEN_KWARGS)["text"]
+     # text = infer(model, file, with_timestamps)
+     text = infer(model, file, with_timestamps, return_df=True)
+
+     logger.info(f'Transcription by `{model_name}`:\n{text.to_json(orient="index", force_ascii=False, indent=2)}\n')
+
+     # return warn_output + text
+     return text
+
+
+ def yt_transcribe(yt_url, with_timestamps, model_name=DEFAULT_MODEL_NAME):
+     # html_embed_str = _return_yt_html_embed(yt_url)
+     audio_file_path = download_audio_from_youtube(yt_url)
+
+     model = maybe_load_cached_pipeline(model_name)
+     # text = model.transcribe("audio.mp3", **GEN_KWARGS)["text"]
+     # text = infer(model, audio_file_path, with_timestamps)
+     text = infer(model, audio_file_path, with_timestamps, return_df=True)
+
+     logger.info(f'Transcription by `{model_name}` of "{yt_url}":\n{text.to_json(orient="index", force_ascii=False, indent=2)}\n')
+
+     # return html_embed_str, text
+     return text
+
+
+ def video_transcribe(video_file_path, with_timestamps, model_name=DEFAULT_MODEL_NAME):
+     if video_file_path is None:
+         raise ValueError("Failed to transcribe video as no video_file_path has been defined")
+
+     audio_file_path = re.sub(r"\.mp4$", ".wav", video_file_path)
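+     # use ffmpeg to extract 16 kHz mono 16-bit PCM audio, the input format expected by Whisper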
+     os.system(f'ffmpeg -hide_banner -loglevel error -y -i "{video_file_path}" -ar 16000 -ac 1 -c:a pcm_s16le "{audio_file_path}"')
+
+     model = maybe_load_cached_pipeline(model_name)
+     # text = model.transcribe("audio.mp3", **GEN_KWARGS)["text"]
+     text = infer(model, audio_file_path, with_timestamps, return_df=True)
+
+     logger.info(f'Transcription by `{model_name}`:\n{text.to_json(orient="index", force_ascii=False, indent=2)}\n')
+
+     return text
+
+
+ # load default model
+ maybe_load_cached_pipeline(DEFAULT_MODEL_NAME)
+
+ # default_text_output_df = pd.DataFrame(columns=["start", "end", "text"])
+ default_text_output_df = pd.DataFrame(columns=["text"])
+
+ with gr.Blocks() as demo:
+
+     with gr.Tab("Transcribe Audio"):
+         gr.Markdown(
+             f"""
+             <div>
+                 <h1 style='text-align: center'>Whisper German Demo: Transcribe Audio</h1>
+             </div>
+             Transcribe long-form microphone recordings or uploaded audio files!
+
+             This demo uses the fine-tuned checkpoint <a href='https://huggingface.co/{DEFAULT_MODEL_NAME}' target='_blank'><b>{DEFAULT_MODEL_NAME}</b></a> to transcribe audio files of arbitrary length.
+
+             Efficient inference is powered by [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2).
+             """
+         )
+
+         microphone_input = gr.inputs.Audio(source="microphone", type="filepath", label="Record", optional=True)
+         upload_input = gr.inputs.Audio(source="upload", type="filepath", label="Upload File", optional=True)
+         with_timestamps_input = gr.Checkbox(label="With timestamps?")
+
+         microphone_transcribe_btn = gr.Button("Transcribe Audio")
+
+         # gr.Markdown('''
+         # Here you will get the generated transcript.
+         # ''')
+
+         # microphone_text_output = gr.outputs.Textbox(label="Transcription")
+         text_output_df2 = gr.DataFrame(
+             value=default_text_output_df,
+             label="Transcription",
+             row_count=(0, "dynamic"),
+             max_rows=10,
+             wrap=True,
+             overflow_row_behaviour="paginate",
+         )
+
+         microphone_transcribe_btn.click(
+             transcribe, inputs=[microphone_input, upload_input, with_timestamps_input], outputs=text_output_df2
+         )
+
+     # with gr.Tab("Transcribe YouTube"):
+     #     gr.Markdown(
+     #         f"""
+     #         <div>
+     #             <h1 style='text-align: center'>Whisper German Demo: Transcribe YouTube</h1>
+     #         </div>
+     #         Transcribe long-form YouTube videos!
+
+     #         This demo uses the fine-tuned checkpoint <a href='https://huggingface.co/{DEFAULT_MODEL_NAME}' target='_blank'><b>{DEFAULT_MODEL_NAME}</b></a> to transcribe video files of arbitrary length.
+     #         """
+     #     )
+
+     #     yt_link_input2 = gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
+     #     with_timestamps_input2 = gr.Checkbox(label="With timestamps?", value=True)
+
+     #     yt_transcribe_btn = gr.Button("Transcribe YouTube")
+
+     #     # yt_text_output = gr.outputs.Textbox(label="Transcription")
+     #     text_output_df3 = gr.DataFrame(
+     #         value=default_text_output_df,
+     #         label="Transcription",
+     #         row_count=(0, "dynamic"),
+     #         max_rows=10,
+     #         wrap=True,
+     #         overflow_row_behaviour="paginate",
+     #     )
+     #     # yt_html_output = gr.outputs.HTML(label="YouTube Page")
+
+     #     yt_transcribe_btn.click(yt_transcribe, inputs=[yt_link_input2, with_timestamps_input2], outputs=[text_output_df3])
+
+     with gr.Tab("Transcribe Video"):
+         gr.Markdown(
+             f"""
+             <div>
+                 <h1 style='text-align: center'>Whisper German Demo: Transcribe Video</h1>
+             </div>
+             Transcribe long-form YouTube videos or uploaded video files!
+
+             This demo uses the fine-tuned checkpoint <a href='https://huggingface.co/{DEFAULT_MODEL_NAME}' target='_blank'><b>{DEFAULT_MODEL_NAME}</b></a> to transcribe video files of arbitrary length.
+
+             Efficient inference is powered by [faster-whisper](https://github.com/guillaumekln/faster-whisper) and [CTranslate2](https://github.com/OpenNMT/CTranslate2).
+             """
+         )
+
+         yt_link_input = gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
+         download_youtube_btn = gr.Button("Download YouTube video")
+         downloaded_video_output = gr.Video(label="Video file", mirror_webcam=False)
+         download_youtube_btn.click(download_video_from_youtube, inputs=[yt_link_input], outputs=[downloaded_video_output])
+
+         with_timestamps_input3 = gr.Checkbox(label="With timestamps?", value=True)
+         video_transcribe_btn = gr.Button("Transcribe video")
+         text_output_df = gr.DataFrame(
+             value=default_text_output_df,
+             label="Transcription",
+             row_count=(0, "dynamic"),
+             max_rows=10,
+             wrap=True,
+             overflow_row_behaviour="paginate",
+         )
+
+         video_transcribe_btn.click(video_transcribe, inputs=[downloaded_video_output, with_timestamps_input3], outputs=[text_output_df])
+
+ # demo.launch(server_name="0.0.0.0", debug=True)
+ # demo.launch(server_name="0.0.0.0", debug=True, share=True)
+ demo.launch(enable_queue=True)
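
For reference, a minimal standalone sketch of the faster-whisper inference path the new script builds on (the model path and audio filename are illustrative; the Space itself resolves the model directory via snapshot_download):

    from faster_whisper import WhisperModel

    # load a CTranslate2-converted checkpoint from a local directory
    model = WhisperModel("whisper-large-v2-cv11-german-ct2", device="cuda", compute_type="float16")

    # transcribe() returns (segments, info); `segments` decodes lazily as it is iterated
    segments, info = model.transcribe("audio.wav", task="transcribe", language="de")
    for segment in segments:
        print(f"[{segment.start:.2f}s -> {segment.end:.2f}s] {segment.text.strip()}")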
run_demo.py → run_demo_hf.py RENAMED
File without changes
run_demo_multi_models.py → run_demo_hf_multiple_models.py RENAMED
File without changes
run_demo_low_api_openai.py → run_demo_openai_layout.py RENAMED
File without changes