sanchit-gandhi HF staff committed on
Commit
2c5665b
1 Parent(s): d111436

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -16
app.py CHANGED
@@ -4,11 +4,9 @@ import time
4
 
5
  import gradio as gr
6
  import numpy as np
7
- import torch
8
  import yt_dlp as youtube_dl
9
  from gradio_client import Client
10
  from pyannote.audio import Pipeline
11
- from transformers.pipelines.audio_utils import ffmpeg_read
12
 
13
 
14
  YT_LENGTH_LIMIT_S = 36000 # limit to 1 hour YouTube files
@@ -189,11 +187,11 @@ def align(transcription, segments, group_by_speaker=True):
189
  return transcription
190
 
191
 
192
- def transcribe(audio_path, group_by_speaker=True):
193
  # run Whisper JAX asynchronously using Gradio client (endpoint)
194
  job = client.submit(
195
  audio_path,
196
- "transcribe",
197
  True,
198
  api_name="/predict_1",
199
  )
@@ -211,11 +209,11 @@ def transcribe(audio_path, group_by_speaker=True):
211
  return transcription
212
 
213
 
214
- def transcribe_yt(yt_url, group_by_speaker=True):
215
  # run Whisper JAX asynchronously using Gradio client (endpoint)
216
  job = client.submit(
217
  yt_url,
218
- "transcribe",
219
  True,
220
  api_name="/predict_2",
221
  )
@@ -224,17 +222,8 @@ def transcribe_yt(yt_url, group_by_speaker=True):
224
  with tempfile.TemporaryDirectory() as tmpdirname:
225
  filepath = os.path.join(tmpdirname, "video.mp4")
226
  download_yt_audio(yt_url, filepath)
 
227
 
228
- with open(filepath, "rb") as f:
229
- inputs = f.read()
230
-
231
- inputs = ffmpeg_read(inputs, SAMPLING_RATE)
232
- inputs = torch.from_numpy(inputs).float()
233
- inputs = inputs.unsqueeze(0)
234
-
235
- diarization = diarization_pipeline(
236
- {"waveform": inputs, "sample_rate": SAMPLING_RATE},
237
- )
238
  segments = diarization.for_json()["content"]
239
 
240
  # only fetch the transcription result after performing diarization
@@ -257,6 +246,7 @@ microphone = gr.Interface(
257
  fn=transcribe,
258
  inputs=[
259
  gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
 
260
  gr.inputs.Checkbox(default=True, label="Group by speaker"),
261
  ],
262
  outputs=[
@@ -272,6 +262,7 @@ audio_file = gr.Interface(
272
  fn=transcribe,
273
  inputs=[
274
  gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
 
275
  gr.inputs.Checkbox(default=True, label="Group by speaker"),
276
  ],
277
  outputs=[
@@ -287,6 +278,7 @@ youtube = gr.Interface(
287
  fn=transcribe_yt,
288
  inputs=[
289
  gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
 
290
  gr.inputs.Checkbox(default=True, label="Group by speaker"),
291
  ],
292
  outputs=[
 
4
 
5
  import gradio as gr
6
  import numpy as np
 
7
  import yt_dlp as youtube_dl
8
  from gradio_client import Client
9
  from pyannote.audio import Pipeline
 
10
 
11
 
12
  YT_LENGTH_LIMIT_S = 36000 # limit to 1 hour YouTube files
 
187
  return transcription
188
 
189
 
190
+ def transcribe(audio_path, task="transcribe", group_by_speaker=True, progress=gr.Progress()):
191
  # run Whisper JAX asynchronously using Gradio client (endpoint)
192
  job = client.submit(
193
  audio_path,
194
+ task,
195
  True,
196
  api_name="/predict_1",
197
  )
 
209
  return transcription
210
 
211
 
212
+ def transcribe_yt(yt_url, task="transcribe", group_by_speaker=True, progress=gr.Progress()):
213
  # run Whisper JAX asynchronously using Gradio client (endpoint)
214
  job = client.submit(
215
  yt_url,
216
+ task,
217
  True,
218
  api_name="/predict_2",
219
  )
 
222
  with tempfile.TemporaryDirectory() as tmpdirname:
223
  filepath = os.path.join(tmpdirname, "video.mp4")
224
  download_yt_audio(yt_url, filepath)
225
+ diarization = diarization_pipeline(filepath)
226
 
 
 
 
 
 
 
 
 
 
 
227
  segments = diarization.for_json()["content"]
228
 
229
  # only fetch the transcription result after performing diarization
 
246
  fn=transcribe,
247
  inputs=[
248
  gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
249
+ gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
250
  gr.inputs.Checkbox(default=True, label="Group by speaker"),
251
  ],
252
  outputs=[
 
262
  fn=transcribe,
263
  inputs=[
264
  gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
265
+ gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
266
  gr.inputs.Checkbox(default=True, label="Group by speaker"),
267
  ],
268
  outputs=[
 
278
  fn=transcribe_yt,
279
  inputs=[
280
  gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
281
+ gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
282
  gr.inputs.Checkbox(default=True, label="Group by speaker"),
283
  ],
284
  outputs=[