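"""Gradio demo for Whisper JAX: fast Whisper speech transcription and translation.

The app sends audio (or a YouTube URL) to a remote JAX/TPU endpoint for inference,
chunking long audio locally and fanning the chunks out over a process pool.
"""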
import base64
import os
from functools import partial
from multiprocessing import Pool

import gradio as gr
import numpy as np
import requests
from processing_whisper import WhisperPrePostProcessor
from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
from transformers.pipelines.audio_utils import ffmpeg_read

title = "Whisper JAX: The Fastest Whisper API ⚡️" |
|
|
|
description = "Whisper JAX is an optimised implementation of the [Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. It runs on JAX with a TPU v4-8 in the backend. Compared to PyTorch on an A100 GPU, it is over **100x** faster, making it the fastest Whisper API available." |
|
|
|
API_URL = os.getenv("API_URL") |
|
API_URL_FROM_FEATURES = os.getenv("API_URL_FROM_FEATURES") |
|
|
|
article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX code and Gradio demo by 🤗 Hugging Face." |
|
|
|

language_names = sorted(TO_LANGUAGE_CODE.keys())  # every language Whisper supports; not currently wired into the UI
CHUNK_LENGTH_S = 30  # length of each audio chunk, in seconds
BATCH_SIZE = 16  # number of chunks per batch sent to the endpoint
NUM_PROC = 16  # worker processes used to parallelise requests


def query(payload):
    """POST a JSON payload to the main API endpoint and return (decoded JSON, status code)."""
    response = requests.post(API_URL, json=payload)
    return response.json(), response.status_code
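
# The request/response schema here is inferred from how this client uses the endpoint,
# not from the server code: API_URL accepts {"inputs", "task", "return_timestamps"} and
# optionally {"language"}, and returns {"text", "chunks"} on success or {"detail"} on error.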


def inference(inputs, language=None, task=None, return_timestamps=False):
    payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}

    if language:
        payload["language"] = language

    data, status_code = query(payload)

    if status_code == 200:
        text = data["text"]
        # timestamp chunks are only present in successful responses
        timestamps = data.get("chunks") if return_timestamps else None
    else:
        text = data["detail"]
        timestamps = None

    return text, timestamps


def chunked_query(payload):
    """POST a pre-processed feature batch to the chunked-inference endpoint."""
    response = requests.post(API_URL_FROM_FEATURES, json=payload)
    return response.json()


def forward(batch, task=None, return_timestamps=False):
    # serialise the float array as base64 so it can travel inside a JSON payload;
    # the original shape is sent alongside so the server can reconstruct the array
    feature_shape = batch["input_features"].shape
    batch["input_features"] = base64.b64encode(batch["input_features"].tobytes()).decode()
    outputs = chunked_query(
        {"batch": batch, "task": task, "return_timestamps": return_timestamps, "feature_shape": feature_shape}
    )
    outputs["tokens"] = np.asarray(outputs["tokens"])
    return outputs
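
# As with API_URL, the feature-endpoint contract is inferred from this client code:
# API_URL_FROM_FEATURES takes the base64-encoded feature batch plus its shape and
# returns the predicted tokens for each chunk.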


if __name__ == "__main__":
    processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2")
    pool = Pool(NUM_PROC)
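
    # The pool fans batches out to NUM_PROC workers so that multiple HTTP requests to
    # the feature endpoint are in flight at once. It is created under the __main__
    # guard so worker processes do not recursively re-create it on import.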

    def transcribe_chunked_audio(microphone, file_upload, task, return_timestamps):
        warn_output = ""
        if (microphone is not None) and (file_upload is not None):
            warn_output = (
                "WARNING: You've uploaded an audio file and used the microphone. "
                "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
            )

        elif (microphone is None) and (file_upload is None):
            return "ERROR: You have to either use the microphone or upload an audio file", None

        inputs = microphone if microphone is not None else file_upload

        # read the raw bytes and decode them to a float array at the model's sampling rate
        with open(inputs, "rb") as f:
            inputs = f.read()

        inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
        inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}

        # split the audio into 30s chunks and batch them for the endpoint
        dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)

        try:
            model_outputs = pool.map(partial(forward, task=task, return_timestamps=return_timestamps), dataloader)
        except ValueError as err:
            return str(err), None

        post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
        timestamps = post_processed.get("chunks")
        return warn_output + post_processed["text"], timestamps

    def _return_yt_html_embed(yt_url):
        # NOTE: assumes a "watch?v=..." style URL; short youtu.be links will not parse here
        video_id = yt_url.split("?v=")[-1]
        html_embed_str = (
            f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
            " </center>"
        )
        return html_embed_str

    def transcribe_youtube(yt_url, task, return_timestamps):
        html_embed_str = _return_yt_html_embed(yt_url)
        text, timestamps = inference(inputs=yt_url, task=task, return_timestamps=return_timestamps)
        return html_embed_str, text, timestamps

    audio_chunked = gr.Interface(
        fn=transcribe_chunked_audio,
        inputs=[
            gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
            gr.inputs.Audio(source="upload", optional=True, type="filepath"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.Textbox(label="Transcription"),
            gr.outputs.Textbox(label="Timestamps"),
        ],
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )

    youtube = gr.Interface(
        fn=transcribe_youtube,
        inputs=[
            gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.HTML(label="Video"),
            gr.outputs.Textbox(label="Transcription"),
            gr.outputs.Textbox(label="Timestamps"),
        ],
        allow_flagging="never",
        title=title,
        examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]],
        cache_examples=False,
        description=description,
        article=article,
    )

    demo = gr.Blocks()

    with demo:
        gr.TabbedInterface([audio_chunked, youtube], ["Transcribe Audio", "Transcribe YouTube"])

    # queue incoming requests so long-running transcriptions don't block other users
    demo.queue()
    demo.launch()