Spaces:
Runtime error
Runtime error
| import base64 | |
| import os | |
| from functools import partial | |
| from multiprocessing import Pool | |
| import gradio as gr | |
| import numpy as np | |
| import requests | |
| from processing_whisper import WhisperPrePostProcessor | |
| from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE | |
| from transformers.pipelines.audio_utils import ffmpeg_read | |
# Endpoints of the transcription backend, read from the environment.
# `API_URL` is used by `query`, `API_URL_FROM_FEATURES` by `chunked_query`.
API_URL = os.getenv("API_URL")
API_URL_FROM_FEATURES = os.getenv("API_URL_FROM_FEATURES")

# Text shown on the Gradio pages (passed as title/description/article to the interfaces).
title = "Whisper JAX: The Fastest Whisper API ⚡️"
description = (
    "Whisper JAX is an optimised implementation of the "
    "[Whisper model](https://huggingface.co/openai/whisper-large-v2) by OpenAI. "
    "It runs on JAX with a TPU v4-8 in the backend. "
    "Compared to PyTorch on an A100 GPU, it is over **100x** faster, "
    "making it the fastest Whisper API available."
)
article = (
    "Whisper large-v2 model by OpenAI. "
    "Backend running JAX on a TPU v4-8 through the generous support of the "
    "[TRC](https://sites.research.google/trc/about/) programme. "
    "Whisper JAX code and Gradio demo by 🤗 Hugging Face."
)

# Sorted keys of the tokenizer's TO_LANGUAGE_CODE mapping.
language_names = sorted(TO_LANGUAGE_CODE)

# Chunk length (seconds) and batch size handed to the pre-processor, and the
# number of worker processes in the request pool.
CHUNK_LENGTH_S = 30
BATCH_SIZE = 16
NUM_PROC = 16
def query(payload):
    """POST `payload` as JSON to `API_URL`.

    Returns:
        A `(body, status_code)` tuple: the JSON-decoded response body and the
        HTTP status code of the response.
    """
    resp = requests.post(API_URL, json=payload)
    body = resp.json()
    return body, resp.status_code
def inference(inputs, language=None, task=None, return_timestamps=False):
    """Send a transcription request through `query` and unpack the response.

    Args:
        inputs: Value forwarded under the "inputs" key of the request payload.
        language: Optional language; only added to the payload when truthy.
        task: Value forwarded under the "task" key of the request payload.
        return_timestamps: Whether timestamps are requested from the API and
            returned to the caller.

    Returns:
        A `(text, timestamps)` tuple. On a 200 response, `text` is the
        response's "text" field and `timestamps` is its "chunks" field when
        `return_timestamps` is set (otherwise `None`). On any other status,
        `text` is the response's "detail" field and `timestamps` is `None`.
    """
    payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}

    # language can come as an empty string from the Gradio `None` default, so we handle it separately
    if language:
        payload["language"] = language

    data, status_code = query(payload)

    if status_code != 200:
        # Failed request: surface the error detail. "chunks" is deliberately not
        # read here, so asking for timestamps cannot turn an API error into a
        # KeyError that hides the actual message.
        return data["detail"], None

    timestamps = data["chunks"] if return_timestamps else None
    return data["text"], timestamps
def chunked_query(payload):
    """POST `payload` as JSON to `API_URL_FROM_FEATURES` and return the decoded JSON body."""
    return requests.post(API_URL_FROM_FEATURES, json=payload).json()
def forward(batch, task=None, return_timestamps=False):
    """Send one batch of input features to the features endpoint.

    The array under `batch["input_features"]` is serialised as base64 text of
    its raw bytes; its shape is sent alongside under "feature_shape".

    Args:
        batch: Mapping holding an "input_features" array (must expose `.shape`
            and `.tobytes()`). It is not modified.
        task: Value forwarded under the "task" key of the request.
        return_timestamps: Value forwarded under the "return_timestamps" key.

    Returns:
        The decoded response from `chunked_query`, with its "tokens" entry
        converted to a NumPy array.
    """
    features = batch["input_features"]

    # Build the request from a shallow copy so the caller's batch keeps its
    # array; overwriting it in place would make a second call on the same
    # batch fail when reading `.shape` from a string.
    request_batch = dict(batch)
    request_batch["input_features"] = base64.b64encode(features.tobytes()).decode()

    outputs = chunked_query(
        {
            "batch": request_batch,
            "task": task,
            "return_timestamps": return_timestamps,
            "feature_shape": features.shape,
        }
    )
    outputs["tokens"] = np.asarray(outputs["tokens"])
    return outputs
if __name__ == "__main__":
    # Pre/post-processor, created once and shared by the audio handler below.
    processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2")
    # Worker processes used to run `forward` over the batches of one audio file.
    pool = Pool(NUM_PROC)

    def transcribe_chunked_audio(microphone, file_upload, task, return_timestamps):
        """Transcribe a recorded or uploaded audio file.

        Args:
            microphone: Path of the microphone recording, or `None`.
            file_upload: Path of the uploaded file, or `None`.
            task: Task selected in the UI ("transcribe" or "translate").
            return_timestamps: Whether timestamps were requested.

        Returns:
            A `(text, timestamps)` tuple matching the two output textboxes.
            When something goes wrong, `text` holds the message and
            `timestamps` is `None`.
        """
        if microphone is None and file_upload is None:
            # The interface registers two outputs, so two values must be
            # returned even on this error path.
            return "ERROR: You have to either use the microphone or upload an audio file", None

        warn_output = ""
        if microphone is not None and file_upload is not None:
            warn_output = (
                "WARNING: You've uploaded an audio file and used the microphone. "
                "The recorded file from the microphone will be used and the uploaded audio will be discarded.\n"
            )

        # The microphone recording takes precedence over the upload.
        audio_path = microphone if microphone is not None else file_upload
        with open(audio_path, "rb") as f:
            audio_bytes = f.read()

        sampling_rate = processor.feature_extractor.sampling_rate
        waveform = ffmpeg_read(audio_bytes, sampling_rate)
        inputs = {"array": waveform, "sampling_rate": sampling_rate}
        dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)

        try:
            model_outputs = pool.map(partial(forward, task=task, return_timestamps=return_timestamps), dataloader)
        except ValueError as err:
            # pre-processor does all the necessary compatibility checks for our audio inputs
            # Return the message text, not the exception object, to the textbox.
            return str(err), None

        post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
        timestamps = post_processed.get("chunks")
        return warn_output + post_processed["text"], timestamps

    def _return_yt_html_embed(yt_url):
        """Build the HTML snippet that embeds the YouTube player for `yt_url`."""
        import html

        # Take what follows "?v=" and drop any further query parameters
        # (e.g. "&t=30s"), which would otherwise end up in the embed URL.
        video_id = yt_url.split("?v=")[-1].split("&")[0]
        # The value comes straight from a user-editable textbox: escape it
        # before interpolating it into markup.
        video_id = html.escape(video_id, quote=True)
        HTML_str = (
            f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
            " </center>"
        )
        return HTML_str

    def transcribe_youtube(yt_url, task, return_timestamps):
        """Transcribe a YouTube video by URL.

        Returns:
            A `(html_embed, text, timestamps)` tuple matching the three
            outputs of the YouTube interface.
        """
        html_embed_str = _return_yt_html_embed(yt_url)
        text, timestamps = inference(inputs=yt_url, task=task, return_timestamps=return_timestamps)
        return html_embed_str, text, timestamps

    # NOTE(review): `gr.inputs` / `gr.outputs` are the legacy Gradio component
    # namespaces; confirm the Gradio version pinned for this app still has them.
    audio_chunked = gr.Interface(
        fn=transcribe_chunked_audio,
        inputs=[
            gr.inputs.Audio(source="microphone", optional=True, type="filepath"),
            gr.inputs.Audio(source="upload", optional=True, type="filepath"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.Textbox(label="Transcription"),
            gr.outputs.Textbox(label="Timestamps"),
        ],
        allow_flagging="never",
        title=title,
        description=description,
        article=article,
    )

    youtube = gr.Interface(
        fn=transcribe_youtube,
        inputs=[
            gr.inputs.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL"),
            gr.inputs.Radio(["transcribe", "translate"], label="Task", default="transcribe"),
            gr.inputs.Checkbox(default=False, label="Return timestamps"),
        ],
        outputs=[
            gr.outputs.HTML(label="Video"),
            gr.outputs.Textbox(label="Transcription"),
            gr.outputs.Textbox(label="Timestamps"),
        ],
        allow_flagging="never",
        title=title,
        examples=[["https://www.youtube.com/watch?v=m8u-18Q0s7I", "transcribe", False]],
        cache_examples=False,
        description=description,
        article=article,
    )

    # Both interfaces are shown as tabs of a single app.
    demo = gr.Blocks()

    with demo:
        gr.TabbedInterface([audio_chunked, youtube], ["Transcribe Audio", "Transcribe YouTube"])

    demo.queue()
    demo.launch()