Spaces: modify app

app.py CHANGED
@@ -1,17 +1,27 @@
-import base64
+import logging
 import math
 import os
 import time
 from multiprocessing import Pool
 
 import gradio as gr
+import jax.numpy as jnp
 import numpy as np
 import pytube
-import requests
-from processing_whisper import WhisperPrePostProcessor
+from jax.experimental.compilation_cache import compilation_cache as cc
 from transformers.models.whisper.tokenization_whisper import TO_LANGUAGE_CODE
 from transformers.pipelines.audio_utils import ffmpeg_read
 
+from whisper_jax import FlaxWhisperPipline
+
+
+cc.initialize_cache("./jax_cache")
+checkpoint = "openai/whisper-large-v2"
+BATCH_SIZE = 16
+CHUNK_LENGTH_S = 30
+NUM_PROC = 8
+FILE_LIMIT_MB = 1000
+YT_ATTEMPT_LIMIT = 3
 
 title = "Whisper JAX: The Fastest Whisper API ⚡️"
 
@@ -24,56 +34,15 @@ To skip the queue, you may wish to create your own inference endpoint, details f
 
 article = "Whisper large-v2 model by OpenAI. Backend running JAX on a TPU v4-8 through the generous support of the [TRC](https://sites.research.google/trc/about/) programme. Whisper JAX [code](https://github.com/sanchit-gandhi/whisper-jax) and Gradio demo by 🤗 Hugging Face."
 
-API_URL = os.getenv("API_URL")
-API_URL_FROM_FEATURES = os.getenv("API_URL_FROM_FEATURES")
 language_names = sorted(TO_LANGUAGE_CODE.keys())
-CHUNK_LENGTH_S = 30
-BATCH_SIZE = 16
-NUM_PROC = 16
-FILE_LIMIT_MB = 1000
-
-
-def query(payload):
-    response = requests.post(API_URL, json=payload)
-    return response.json(), response.status_code
-
 
-def inference(inputs, task=None, return_timestamps=False):
-    payload = {"inputs": inputs, "task": task, "return_timestamps": return_timestamps}
-
-    data, status_code = query(payload)
-
-    if status_code != 200:
-        # error with our request - return the details to the user
-        raise gr.Error(data["detail"])
-
-    text = data["text"]
-    timestamps = data.get("chunks")
-    if timestamps is not None:
-        timestamps = [
-            f"[{format_timestamp(chunk['timestamp'][0])} -> {format_timestamp(chunk['timestamp'][1])}] {chunk['text']}"
-            for chunk in timestamps
-        ]
-        text = "\n".join(str(feature) for feature in timestamps)
-    return text
-
-
-def chunked_query(payload):
-    response = requests.post(API_URL_FROM_FEATURES, json=payload)
-    return response.json(), response.status_code
-
-
-def forward(batch, task=None, return_timestamps=False):
-    feature_shape = batch["input_features"].shape
-    batch["input_features"] = base64.b64encode(batch["input_features"].tobytes()).decode()
-    outputs, status_code = chunked_query(
-        {"batch": batch, "task": task, "return_timestamps": return_timestamps, "feature_shape": feature_shape}
-    )
-    if status_code != 200:
-        # error with our request - return the details to the user
-        raise gr.Error(outputs["detail"])
-    outputs["tokens"] = np.asarray(outputs["tokens"])
-    return outputs
+logger = logging.getLogger("whisper-jax-app")
+logger.setLevel(logging.INFO)
+ch = logging.StreamHandler()
+ch.setLevel(logging.INFO)
+formatter = logging.Formatter("%(asctime)s;%(levelname)s;%(message)s", "%Y-%m-%d %H:%M:%S")
+ch.setFormatter(formatter)
+logger.addHandler(ch)
 
 
 def identity(batch):
@@ -102,10 +71,10 @@ def format_timestamp(seconds: float, always_include_hours: bool = False, decimal
 
 
 if __name__ == "__main__":
-    processor = WhisperPrePostProcessor.from_pretrained("openai/whisper-large-v2")
+    pipeline = FlaxWhisperPipline(checkpoint, dtype=jnp.bfloat16, batch_size=BATCH_SIZE)
     stride_length_s = CHUNK_LENGTH_S / 6
-    chunk_len = round(CHUNK_LENGTH_S * processor.feature_extractor.sampling_rate)
-    stride_left = stride_right = round(stride_length_s * processor.feature_extractor.sampling_rate)
+    chunk_len = round(CHUNK_LENGTH_S * pipeline.feature_extractor.sampling_rate)
+    stride_left = stride_right = round(stride_length_s * pipeline.feature_extractor.sampling_rate)
     step = chunk_len - stride_left - stride_right
    pool = Pool(NUM_PROC)
 
@@ -118,18 +87,21 @@ if __name__ == "__main__":
             range(num_batches)
         )  # Gradio progress bar not compatible with generator, see https://github.com/gradio-app/gradio/issues/3841
 
-        dataloader = processor.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
+        dataloader = pipeline.preprocess_batch(inputs, chunk_length_s=CHUNK_LENGTH_S, batch_size=BATCH_SIZE)
         progress(0, desc="Pre-processing audio file...")
+        logger.info("Pre-processing audio file...")
         dataloader = pool.map(identity, dataloader)
 
         model_outputs = []
         start_time = time.time()
         # iterate over our chunked audio samples
         for batch, _ in zip(dataloader, progress.tqdm(dummy_batches, desc="Transcribing...")):
-            model_outputs.append(forward(batch, task=task, return_timestamps=return_timestamps))
+            model_outputs.append(
+                pipeline.forward(batch, batch_size=BATCH_SIZE, task=task, return_timestamps=return_timestamps)
+            )
         runtime = time.time() - start_time
 
-        post_processed = processor.postprocess(model_outputs, return_timestamps=return_timestamps)
+        post_processed = pipeline.postprocess(model_outputs, return_timestamps=return_timestamps)
         text = post_processed["text"]
         timestamps = post_processed.get("chunks")
         if timestamps is not None:
@@ -138,14 +110,18 @@ if __name__ == "__main__":
                 for chunk in timestamps
             ]
             text = "\n".join(str(feature) for feature in timestamps)
+        logger.info("done pre-processing")
         return text, runtime
 
     def transcribe_chunked_audio(inputs, task, return_timestamps, progress=gr.Progress()):
         progress(0, desc="Loading audio file...")
+        logger.info("Loading audio file...")
         if inputs is None:
+            logger.warning("No audio file")
             raise gr.Error("No audio file submitted! Please upload an audio file before submitting your request.")
         file_size_mb = os.stat(inputs).st_size / (1024 * 1024)
         if file_size_mb > FILE_LIMIT_MB:
+            logger.warning("Max file size exceeded")
             raise gr.Error(
                 f"File size exceeds file size limit. Got file of size {file_size_mb:.2f}MB for a limit of {FILE_LIMIT_MB}MB."
             )
@@ -153,9 +129,10 @@ if __name__ == "__main__":
         with open(inputs, "rb") as f:
             inputs = f.read()
 
-        inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
-        inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
+        inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
+        inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
         text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
+        logger.info("done loading")
         return text, runtime
 
     def _return_yt_html_embed(yt_url):
@@ -168,14 +145,21 @@ if __name__ == "__main__":
 
     def transcribe_youtube(yt_url, task, return_timestamps, progress=gr.Progress(), max_filesize=75.0):
         progress(0, desc="Loading audio file...")
+        logger.info("Loading youtube file...")
         html_embed_str = _return_yt_html_embed(yt_url)
-        try:
-            yt = pytube.YouTube(yt_url)
-            stream = yt.streams.filter(only_audio=True)[0]
-        except KeyError:
-            raise gr.Error("An error occurred while loading the YouTube video. Please try again.")
+
+        for attempt in range(YT_ATTEMPT_LIMIT):
+            try:
+                yt = pytube.YouTube(yt_url)
+                stream = yt.streams.filter(only_audio=True)[0]
+                break
+            except KeyError:
+                if attempt + 1 == YT_ATTEMPT_LIMIT:
+                    logger.warning("YouTube error")
+                    raise gr.Error("An error occurred while loading the YouTube video. Please try again.")
 
         if stream.filesize_mb > max_filesize:
+            logger.warning("Max YouTube size exceeded")
             raise gr.Error(f"Maximum YouTube file size is {max_filesize}MB, got {stream.filesize_mb:.2f}MB.")
 
         stream.download(filename="audio.mp3")
@@ -183,9 +167,10 @@ if __name__ == "__main__":
         with open("audio.mp3", "rb") as f:
             inputs = f.read()
 
-        inputs = ffmpeg_read(inputs, processor.feature_extractor.sampling_rate)
-        inputs = {"array": inputs, "sampling_rate": processor.feature_extractor.sampling_rate}
+        inputs = ffmpeg_read(inputs, pipeline.feature_extractor.sampling_rate)
+        inputs = {"array": inputs, "sampling_rate": pipeline.feature_extractor.sampling_rate}
         text, runtime = tqdm_generate(inputs, task=task, return_timestamps=return_timestamps, progress=progress)
+        logger.info("done youtube")
        return html_embed_str, text, runtime
 
     microphone_chunked = gr.Interface(
@@ -247,5 +232,5 @@ if __name__ == "__main__":
     with demo:
         gr.TabbedInterface([microphone_chunked, audio_chunked, youtube], ["Microphone", "Audio File", "YouTube"])
 
-    demo.queue(concurrency_count=
-    demo.launch(
+    demo.queue(concurrency_count=1, max_size=5)
+    demo.launch(server_name="0.0.0.0", show_api=False)
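
For reference, the FlaxWhisperPipline that the updated app.py is built around can also be driven directly, without the Gradio wiring. A minimal sketch along the lines of the whisper-jax README (the checkpoint matches the one used above; "audio.mp3" is a placeholder path):

    import jax.numpy as jnp
    from whisper_jax import FlaxWhisperPipline

    # load the model in half precision with a fixed batch size; the first
    # call JIT-compiles the forward pass, later calls reuse the compilation
    pipeline = FlaxWhisperPipline("openai/whisper-large-v2", dtype=jnp.bfloat16, batch_size=16)

    # transcribe a local audio file, returning segment-level timestamps
    outputs = pipeline("audio.mp3", task="transcribe", return_timestamps=True)
    print(outputs["text"])    # full transcription
    print(outputs["chunks"])  # list of {"timestamp": (start, end), "text": ...} segments

The app splits this single call into its preprocess_batch / forward / postprocess stages instead, so it can update the Gradio progress bar between batches.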