Spaces:
Runtime error
Runtime error
Commit
•
2c5665b
1
Parent(s):
d111436
Update app.py
Browse files
app.py
CHANGED
@@ -4,11 +4,9 @@ import time
|
|
4 |
|
5 |
import gradio as gr
|
6 |
import numpy as np
|
7 |
-
import torch
|
8 |
import yt_dlp as youtube_dl
|
9 |
from gradio_client import Client
|
10 |
from pyannote.audio import Pipeline
|
11 |
-
from transformers.pipelines.audio_utils import ffmpeg_read
|
12 |
|
13 |
|
14 |
YT_LENGTH_LIMIT_S = 36000 # limit YouTube files to 10 hours (36000 s)
|
@@ -189,11 +187,11 @@ def align(transcription, segments, group_by_speaker=True):
|
|
189 |
return transcription
|
190 |
|
191 |
|
192 |
-
def transcribe(audio_path, group_by_speaker=True):
|
193 |
# run Whisper JAX asynchronously using Gradio client (endpoint)
|
194 |
job = client.submit(
|
195 |
audio_path,
|
196 |
-
|
197 |
True,
|
198 |
api_name="/predict_1",
|
199 |
)
|
@@ -211,11 +209,11 @@ def transcribe(audio_path, group_by_speaker=True):
|
|
211 |
return transcription
|
212 |
|
213 |
|
214 |
-
def transcribe_yt(yt_url, group_by_speaker=True):
|
215 |
# run Whisper JAX asynchronously using Gradio client (endpoint)
|
216 |
job = client.submit(
|
217 |
yt_url,
|
218 |
-
|
219 |
True,
|
220 |
api_name="/predict_2",
|
221 |
)
|
@@ -224,17 +222,8 @@ def transcribe_yt(yt_url, group_by_speaker=True):
|
|
224 |
with tempfile.TemporaryDirectory() as tmpdirname:
|
225 |
filepath = os.path.join(tmpdirname, "video.mp4")
|
226 |
download_yt_audio(yt_url, filepath)
|
|
|
227 |
|
228 |
-
with open(filepath, "rb") as f:
|
229 |
-
inputs = f.read()
|
230 |
-
|
231 |
-
inputs = ffmpeg_read(inputs, SAMPLING_RATE)
|
232 |
-
inputs = torch.from_numpy(inputs).float()
|
233 |
-
inputs = inputs.unsqueeze(0)
|
234 |
-
|
235 |
-
diarization = diarization_pipeline(
|
236 |
-
{"waveform": inputs, "sample_rate": SAMPLING_RATE},
|
237 |
-
)
|
238 |
segments = diarization.for_json()["content"]
|
239 |
|
240 |
# only fetch the transcription result after performing diarization
|
@@ -257,6 +246,7 @@ microphone = gr.Interface(
|
|
257 |
fn=transcribe,
|
258 |
inputs=[
|
259 |
gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
|
|
|
260 |
gr.inputs.Checkbox(default=True, label="Group by speaker"),
|
261 |
],
|
262 |
outputs=[
|
@@ -272,6 +262,7 @@ audio_file = gr.Interface(
|
|
272 |
fn=transcribe,
|
273 |
inputs=[
|
274 |
gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
|
|
|
275 |
gr.inputs.Checkbox(default=True, label="Group by speaker"),
|
276 |
],
|
277 |
outputs=[
|
@@ -287,6 +278,7 @@ youtube = gr.Interface(
|
|
287 |
fn=transcribe_yt,
|
288 |
inputs=[
|
289 |
gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
|
|
|
290 |
gr.inputs.Checkbox(default=True, label="Group by speaker"),
|
291 |
],
|
292 |
outputs=[
|
|
|
4 |
|
5 |
import gradio as gr
|
6 |
import numpy as np
|
|
|
7 |
import yt_dlp as youtube_dl
|
8 |
from gradio_client import Client
|
9 |
from pyannote.audio import Pipeline
|
|
|
10 |
|
11 |
|
12 |
YT_LENGTH_LIMIT_S = 36000 # limit YouTube files to 10 hours (36000 s)
|
|
|
187 |
return transcription
|
188 |
|
189 |
|
190 |
+
def transcribe(audio_path, task="transcribe", group_by_speaker=True, progress=gr.Progress()):
|
191 |
# run Whisper JAX asynchronously using Gradio client (endpoint)
|
192 |
job = client.submit(
|
193 |
audio_path,
|
194 |
+
task,
|
195 |
True,
|
196 |
api_name="/predict_1",
|
197 |
)
|
|
|
209 |
return transcription
|
210 |
|
211 |
|
212 |
+
def transcribe_yt(yt_url, task="transcribe", group_by_speaker=True, progress=gr.Progress()):
|
213 |
# run Whisper JAX asynchronously using Gradio client (endpoint)
|
214 |
job = client.submit(
|
215 |
yt_url,
|
216 |
+
task,
|
217 |
True,
|
218 |
api_name="/predict_2",
|
219 |
)
|
|
|
222 |
with tempfile.TemporaryDirectory() as tmpdirname:
|
223 |
filepath = os.path.join(tmpdirname, "video.mp4")
|
224 |
download_yt_audio(yt_url, filepath)
|
225 |
+
diarization = diarization_pipeline(filepath)
|
226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
segments = diarization.for_json()["content"]
|
228 |
|
229 |
# only fetch the transcription result after performing diarization
|
|
|
246 |
fn=transcribe,
|
247 |
inputs=[
|
248 |
gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
|
249 |
+
gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
|
250 |
gr.inputs.Checkbox(default=True, label="Group by speaker"),
|
251 |
],
|
252 |
outputs=[
|
|
|
262 |
fn=transcribe,
|
263 |
inputs=[
|
264 |
gr.inputs.Audio(source="upload", optional=True, label="Audio file", type="filepath"),
|
265 |
+
gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
|
266 |
gr.inputs.Checkbox(default=True, label="Group by speaker"),
|
267 |
],
|
268 |
outputs=[
|
|
|
278 |
fn=transcribe_yt,
|
279 |
inputs=[
|
280 |
gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
|
281 |
+
gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
|
282 |
gr.inputs.Checkbox(default=True, label="Group by speaker"),
|
283 |
],
|
284 |
outputs=[
|